Skip to content

Commit

Permalink
ensure dimension order from BaseRegularGridDatastore.stack_grid_coords
Browse files Browse the repository at this point in the history
  • Loading branch information
leifdenby committed Nov 12, 2024
1 parent ff02af7 commit 490f1e3
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 15 deletions.
17 changes: 16 additions & 1 deletion neural_lam/datastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,22 @@ def stack_grid_coords(
if "grid_index" in da_or_ds.dims:
return da_or_ds

return da_or_ds.stack(grid_index=self.CARTESIAN_COORDS)
da_or_ds_stacked = da_or_ds.stack(grid_index=self.CARTESIAN_COORDS)
# find the feature dimension, which has named with the format
# `{category}_feature`
potential_feature_dims = [
d for d in da_or_ds_stacked.dims if d.endswith("_feature")
]
if not len(potential_feature_dims) == 1:
raise ValueError(
"Expected exactly one feature dimension in the stacked data, "
f"got {potential_feature_dims}"
)
feature_dim = potential_feature_dims[0]

# ensure that grid_index is the first dimension, and the feature
# dimension is the second
return da_or_ds_stacked.transpose("grid_index", feature_dim, ...)

@property
@functools.lru_cache
Expand Down
26 changes: 18 additions & 8 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,14 +64,24 @@ def download_meps_example_reduced_dataset():
with open(config_path, "w") as f:
yaml.dump(config, f)

# create parameters
compute_standardization_stats_meps.main(
datastore_config_path=config_path,
batch_size=8,
step_length=3,
n_workers=1,
distributed=False,
)
# create parameters, only run if the files we expect are not present
expected_parameter_files = [
"parameter_mean.pt",
"parameter_std.pt",
"diff_mean.pt",
"diff_std.pt",
]
expected_parameter_filepaths = [
dataset_path / "static" / fn for fn in expected_parameter_files
]
if any(not p.exists() for p in expected_parameter_filepaths):
compute_standardization_stats_meps.main(
datastore_config_path=config_path,
batch_size=8,
step_length=3,
n_workers=1,
distributed=False,
)

return config_path

Expand Down
10 changes: 4 additions & 6 deletions tests/test_datastores.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,22 +310,20 @@ def get_grid_shape_state(datastore_name):


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_stacking_grid_coords(datastore_name):
@pytest.mark.parametrize("category", ["state", "forcing", "static"])
def test_stacking_grid_coords(datastore_name, category):
"""Check that the `datastore.stack_grid_coords` method is implemented."""
datastore = init_datastore_example(datastore_name)

if not isinstance(datastore, BaseRegularGridDatastore):
pytest.skip("Datastore does not implement `BaseCartesianDatastore`")

da_static = datastore.get_dataarray("static", split=None)
da_static = datastore.get_dataarray(category=category, split=None)

da_static_unstacked = datastore.unstack_grid_coords(da_static).load()
da_static_test = datastore.stack_grid_coords(da_static_unstacked)

# XXX: for the moment unstacking doesn't guarantee the order of the
# dimensions maybe we should enforce this?
da_static_test = da_static_test.transpose(*da_static.dims)

assert da_static.dims == da_static_test.dims
xr.testing.assert_equal(da_static, da_static_test)


Expand Down

0 comments on commit 490f1e3

Please sign in to comment.