Skip to content

Commit

Permalink
check for consistency of num features across splits
Browse files Browse the repository at this point in the history
  • Loading branch information
leifdenby committed Jul 24, 2024
1 parent cfb0618 commit 8698719
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions tests/test_datastores.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ def test_get_dataarray(datastore_name):
datastore = init_datastore(datastore_name)

for category in ["state", "forcing", "static"]:
n_features = {}
for split in ["train", "val", "test"]:
expected_dims = ["grid_index", f"{category}_feature"]
if category != "static":
Expand All @@ -172,6 +173,11 @@ def test_get_dataarray(datastore_name):
if isinstance(datastore, BaseCartesianDatastore):
assert da.grid_index.size == grid_shape.x * grid_shape.y

n_features[split] = da[category + "_feature"].size

# check that the number of features is the same for all splits
assert n_features["train"] == n_features["val"] == n_features["test"]


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_boundary_mask(datastore_name):
Expand Down

0 comments on commit 8698719

Please sign in to comment.