Skip to content

Commit

Permalink
Merge branch 'feat/datastores' of github.com:leifdenby/neural-lam int…
Browse files Browse the repository at this point in the history
…o feat/datastores
  • Loading branch information
sadamov committed Nov 13, 2024
2 parents b56e47a + a95eb5a commit 258079c
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 36 deletions.
4 changes: 3 additions & 1 deletion neural_lam/create_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -590,7 +590,9 @@ def cli(input_args=None):
)
args = parser.parse_args(input_args)

assert args.config is not None, "Specify your config with --config_path"
assert (
args.config_path is not None
), "Specify your config with --config_path"

# Load neural-lam configuration and datastore to use
_, datastore = load_config_and_datastore(config_path=args.config_path)
Expand Down
78 changes: 63 additions & 15 deletions neural_lam/datastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,53 @@ def state_feature_weights_values(self) -> List[float]:
"""
pass

def expected_dim_order(self, category: str = None) -> List[str]:
    """
    Return the expected dimension order for the dataarray or dataset
    returned by `get_dataarray` for the given category of data. The
    dimension order is the order of the dimensions in the dataarray or
    dataset, and is used to check that the data is in the expected format.
    This is necessary so that when stacking and unstacking the spatial grid
    we can ensure that the dimension order is the same as what is returned
    from `get_dataarray`. It also ensures that downstream uses of a
    datastore (e.g. WeatherDataset) see the data in a common structure.

    If the category is None, the feature dimension is omitted, but the
    time dimension(s) are still included.

    NOTE(review): the original docstring claimed that ``category=None``
    implies a 1D scalar field varying only with grid-index, which does
    not match the implementation (``None != "static"`` so time dims are
    appended) — confirm which behaviour is intended.

    Parameters
    ----------
    category : str, optional
        The category of the dataset (state/forcing/static), or None.

    Returns
    -------
    List[str]
        The expected dimension order for the dataarray or dataset.
    """
    # NOTE: deliberately NOT decorated with `functools.lru_cache`: caching
    # an instance method keys on `self` and keeps every datastore instance
    # alive for the lifetime of the cache (ruff B019), and the cached
    # *mutable* list could be corrupted in place by any caller. The
    # computation below is trivial, so no caching is needed.
    dim_order = ["grid_index"]

    if category is not None:
        # feature dimensions are named with the format `{category}_feature`
        dim_order.append(f"{category}_feature")

    if category != "static":
        # static data does not vary in time
        if self.is_forecast:
            # forecast data is indexed by when the forecast was made and
            # by the lead time within that forecast
            dim_order.extend(["analysis_time", "elapsed_forecast_duration"])
        else:
            # analysis (reanalysis) data has a single time dimension
            dim_order.append("time")

    if self.is_ensemble and category == "state":
        # XXX: for now we only assume ensemble data for state variables
        dim_order.append("ensemble_member")

    return dim_order


@dataclasses.dataclass
class CartesianGridShape:
Expand Down Expand Up @@ -464,21 +511,22 @@ def stack_grid_coords(
return da_or_ds

da_or_ds_stacked = da_or_ds.stack(grid_index=self.CARTESIAN_COORDS)
# find the feature dimension, which has named with the format
# `{category}_feature`
potential_feature_dims = [
d for d in da_or_ds_stacked.dims if d.endswith("_feature")
]
if not len(potential_feature_dims) == 1:
raise ValueError(
"Expected exactly one feature dimension in the stacked data, "
f"got {potential_feature_dims}"
)
feature_dim = potential_feature_dims[0]

# ensure that grid_index is the first dimension, and the feature
# dimension is the second
return da_or_ds_stacked.transpose("grid_index", feature_dim, ...)

# infer what category of data the array represents by finding the
# dimension named in the format `{category}_feature`
category = None
for dim in da_or_ds_stacked.dims:
if dim.endswith("_feature"):
if category is not None:
raise ValueError(
"Multiple dimensions ending with '_feature' found in "
f"dataarray: {da_or_ds_stacked}. Cannot infer category."
)
category = dim.split("_")[0]

dim_order = self.expected_dim_order(category=category)

return da_or_ds_stacked.transpose(*dim_order)

@property
@functools.lru_cache
Expand Down
9 changes: 5 additions & 4 deletions neural_lam/datastore/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,9 +266,7 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
# set multi-index for grid-index
da_category = da_category.set_index(grid_index=self.CARTESIAN_COORDS)

if "time" not in da_category.dims:
return da_category
else:
if "time" in da_category.dims:
t_start = (
self._ds.splits.sel(split_name=split)
.sel(split_part="start")
Expand All @@ -281,7 +279,10 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
.load()
.item()
)
return da_category.sel(time=slice(t_start, t_end))
da_category = da_category.sel(time=slice(t_start, t_end))

dim_order = self.expected_dim_order(category=category)
return da_category.transpose(*dim_order)

def get_standardization_dataarray(self, category: str) -> xr.Dataset:
"""
Expand Down
19 changes: 12 additions & 7 deletions neural_lam/datastore/npyfilesmeps/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,15 +282,16 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
features=features, split=split
)
das.append(da)
da = xr.concat(das, dim="feature").transpose(
"grid_index", "feature"
)
da = xr.concat(das, dim="feature")

else:
raise NotImplementedError(category)

da = da.rename(dict(feature=f"{category}_feature"))

# stack the [x, y] dimensions into a `grid_index` dimension
da = self.stack_grid_coords(da)

# check that we have the right features
actual_features = da[f"{category}_feature"].values.tolist()
expected_features = self.get_vars_names(category=category)
Expand All @@ -299,6 +300,9 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
f"Expected features {expected_features}, got {actual_features}"
)

dim_order = self.expected_dim_order(category=category)
da = da.transpose(*dim_order)

return da

def _get_single_timeseries_dataarray(
Expand Down Expand Up @@ -346,7 +350,11 @@ def _get_single_timeseries_dataarray(
None,
), "Unknown dataset split"
else:
assert split in ("train", "val", "test"), "Unknown dataset split"
assert split in (
"train",
"val",
"test",
), f"Unknown dataset split {split} for features {features}"

if member is not None and features != self.get_vars_names(
category="state"
Expand Down Expand Up @@ -495,9 +503,6 @@ def _get_single_timeseries_dataarray(

da = xr.DataArray(arr_all, dims=dims, coords=coords)

# stack the [x, y] dimensions into a `grid_index` dimension
da = self.stack_grid_coords(da)

return da

def _get_analysis_times(self, split) -> List[np.datetime64]:
Expand Down
4 changes: 3 additions & 1 deletion neural_lam/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,9 @@ def main(input_args=None):
}

# Asserts for arguments
assert args.config is not None, "Specify your config with --config_path"
assert (
args.config_path is not None
), "Specify your config with --config_path"
assert args.model in MODELS, f"Unknown model: {args.model}"
assert args.eval in (
None,
Expand Down
7 changes: 2 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,8 @@

# Initializing variables for the s3 client
S3_BUCKET_NAME = "mllam-testdata"
# S3_ENDPOINT_URL = "https://object-store.os-api.cci1.ecmwf.int"
S3_ENDPOINT_URL = "http://localhost:8000"
# S3_FILE_PATH = "neural-lam/npy/meps_example_reduced.v0.1.0.zip"
# TODO: I will upload this to AWS S3 once I have located the credentials...
S3_FILE_PATH = "meps_example_reduced.v0.2.0.zip"
S3_ENDPOINT_URL = "https://object-store.os-api.cci1.ecmwf.int"
S3_FILE_PATH = "neural-lam/npy/meps_example_reduced.v0.2.0.zip"
S3_FULL_PATH = "/".join([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_FILE_PATH])
TEST_DATA_KNOWN_HASH = (
"7ff2e07e04cfcd77631115f800c9d49188bb2a7c2a2777da3cea219f926d0c86"
Expand Down
2 changes: 1 addition & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def test_config_serialization(state_weighting_config):
training:
state_feature_weighting:
__config_class__: ManualStateFeatureWeighting
values:
weights:
u100m: 1.0
v100m: 1.0
"""
Expand Down
3 changes: 1 addition & 2 deletions tests/test_datastores.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def test_stacking_grid_coords(datastore_name, category):
if not isinstance(datastore, BaseRegularGridDatastore):
pytest.skip("Datastore does not implement `BaseCartesianDatastore`")

da_static = datastore.get_dataarray(category=category, split=None)
da_static = datastore.get_dataarray(category=category, split="train")

da_static_unstacked = datastore.unstack_grid_coords(da_static).load()
da_static_test = datastore.stack_grid_coords(da_static_unstacked)
Expand All @@ -339,7 +339,6 @@ def test_dataarray_shapes(datastore_name):
unstacked_tensor = torch.tensor(
datastore.unstack_grid_coords(static_da).to_numpy(), dtype=torch.float32
).squeeze()
print(static_da)

reshaped_tensor = (
torch.tensor(static_da.to_numpy(), dtype=torch.float32)
Expand Down

0 comments on commit 258079c

Please sign in to comment.