Skip to content

Commit

Permalink
Add test for dataset length using different configs
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Nov 11, 2024
1 parent 74828fa commit 8cc6c3d
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from neural_lam.models.graph_lam import GraphLAM
from neural_lam.weather_dataset import WeatherDataset
from tests.conftest import init_datastore_example
from tests.dummy_datastore import DummyDatastore


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
Expand Down Expand Up @@ -221,3 +222,39 @@ def _create_graph():
batch_device = [part.to(device_name) for part in batch]
model_device.common_step(batch_device)
model_device.training_step(batch_device)


@pytest.mark.parametrize("dataset_config", [

This comment has been minimized.

Copy link
@leifdenby

leifdenby Nov 12, 2024

Member

wow @sadamov adding these tests as examples like this explains the functionality perfectly, thank you ❤ !

This comment has been minimized.

Copy link
@leifdenby

leifdenby Nov 12, 2024

Member

sorry, just realised you did this @joeloskarsson ! Still awesome 🌟 🌟 🌟

{"past": 0, "future": 0, "ar_steps": 1, "exp_len_reduction": 3},
{"past": 2, "future": 0, "ar_steps": 1, "exp_len_reduction": 3},
{"past": 0, "future": 2, "ar_steps": 1, "exp_len_reduction": 5},
{"past": 4, "future": 0, "ar_steps": 1, "exp_len_reduction": 5},
{"past": 0, "future": 0, "ar_steps": 5, "exp_len_reduction": 7},
{"past": 3, "future": 3, "ar_steps": 2, "exp_len_reduction": 8},
])
def test_dataset_length(dataset_config):
"""Check that correct number of samples can be extracted from the dataset,
given a specific configuration of forcing windowing and ar_steps.
"""
# Use dummy datastore of length 10 here, only want to test slicing
# in dataset class
ds_len = 10
datastore = DummyDatastore(n_timesteps=ds_len)

dataset = WeatherDataset(
datastore=datastore,
split="train",
ar_steps=dataset_config["ar_steps"],
include_past_forcing=dataset_config["past"],
include_future_forcing=dataset_config["future"],
)

# We expect dataset to contain this many samples
expected_len = ds_len - dataset_config["exp_len_reduction"]

# Check that datast has correct length
assert len(dataset) == expected_len

# Check that we can actually get last and first sample
dataset[0]
dataset[expected_len - 1]

0 comments on commit 8cc6c3d

Please sign in to comment.