From 490f1e37e12c71617f343571cbb0580e1ad88480 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Tue, 12 Nov 2024 16:17:53 +0100 Subject: [PATCH] ensure dimension order from BaseRegularGridDatastore.stack_grid_coords --- neural_lam/datastore/base.py | 17 ++++++++++++++++- tests/conftest.py | 26 ++++++++++++++++++-------- tests/test_datastores.py | 10 ++++------ 3 files changed, 38 insertions(+), 15 deletions(-) diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index 83610a6..25e7d01 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -463,7 +463,22 @@ def stack_grid_coords( if "grid_index" in da_or_ds.dims: return da_or_ds - return da_or_ds.stack(grid_index=self.CARTESIAN_COORDS) + da_or_ds_stacked = da_or_ds.stack(grid_index=self.CARTESIAN_COORDS) + # find the feature dimension, which has named with the format + # `{category}_feature` + potential_feature_dims = [ + d for d in da_or_ds_stacked.dims if d.endswith("_feature") + ] + if not len(potential_feature_dims) == 1: + raise ValueError( + "Expected exactly one feature dimension in the stacked data, " + f"got {potential_feature_dims}" + ) + feature_dim = potential_feature_dims[0] + + # ensure that grid_index is the first dimension, and the feature + # dimension is the second + return da_or_ds_stacked.transpose("grid_index", feature_dim, ...) @property @functools.lru_cache diff --git a/tests/conftest.py b/tests/conftest.py index 84f6fb2..3dfac91 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -64,14 +64,24 @@ def download_meps_example_reduced_dataset(): with open(config_path, "w") as f: yaml.dump(config, f) - # create parameters - compute_standardization_stats_meps.main( - datastore_config_path=config_path, - batch_size=8, - step_length=3, - n_workers=1, - distributed=False, - ) + # create parameters, only run if the files we expect are not present + expected_parameter_files = [ + "parameter_mean.pt", + "parameter_std.pt", + "diff_mean.pt", + "diff_std.pt", + ] + expected_parameter_filepaths = [ + dataset_path / "static" / fn for fn in expected_parameter_files + ] + if any(not p.exists() for p in expected_parameter_filepaths): + compute_standardization_stats_meps.main( + datastore_config_path=config_path, + batch_size=8, + step_length=3, + n_workers=1, + distributed=False, + ) return config_path diff --git a/tests/test_datastores.py b/tests/test_datastores.py index c0d69ec..00cd508 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -310,22 +310,20 @@ def get_grid_shape_state(datastore_name): @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) -def test_stacking_grid_coords(datastore_name): +@pytest.mark.parametrize("category", ["state", "forcing", "static"]) +def test_stacking_grid_coords(datastore_name, category): """Check that the `datastore.stack_grid_coords` method is implemented.""" datastore = init_datastore_example(datastore_name) if not isinstance(datastore, BaseRegularGridDatastore): pytest.skip("Datastore does not implement `BaseCartesianDatastore`") - da_static = datastore.get_dataarray("static", split=None) + da_static = datastore.get_dataarray(category=category, split=None) da_static_unstacked = datastore.unstack_grid_coords(da_static).load() da_static_test = datastore.stack_grid_coords(da_static_unstacked) - # XXX: for the moment unstacking doesn't guarantee the order of the - # dimensions maybe we should enforce this? - da_static_test = da_static_test.transpose(*da_static.dims) - + assert da_static.dims == da_static_test.dims xr.testing.assert_equal(da_static, da_static_test)