Skip to content

Commit

Permalink
add test for state/forcing values from time-slicing
Browse files Browse the repository at this point in the history
  • Loading branch information
leifdenby committed Nov 14, 2024
1 parent 0f24924 commit acb8ffa
Showing 1 changed file with 139 additions and 0 deletions.
139 changes: 139 additions & 0 deletions tests/test_time_slicing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Third-party
import numpy as np
import pytest
import xarray as xr

# First-party
from neural_lam.datastore.base import BaseDatastore
from neural_lam.weather_dataset import WeatherDataset


class SinglePointDummyDatastore(BaseDatastore):
step_length = 1
config = None
coords_projection = None
num_grid_points = 1
root_path = None

def __init__(self, time_values, state_data, forcing_data, is_forecast):
self._time_values = np.array(time_values)
self._state_data = np.array(state_data)
self._forcing_data = np.array(forcing_data)
self.is_forecast = is_forecast

if is_forecast:
assert self._state_data.ndim == 2
else:
assert self._state_data.ndim == 1

def get_num_data_vars(self, category):
return 1

def get_dataarray(self, category, split):
if category == "state":
values = self._state_data
elif category == "forcing":
values = self._forcing_data
else:
raise NotImplementedError(category)

if self.is_forecast:
raise NotImplementedError()
else:
da = xr.DataArray(
values, dims=["time"], coords={"time": self._time_values}
)
# add `{category}_feature` and `grid_index` dimensions

da = da.expand_dims("grid_index")
da = da.expand_dims(f"{category}_feature")

return da

def get_standardization_dataarray(self, category):
raise NotImplementedError()

def get_xy(self, category):
raise NotImplementedError()

def get_vars_units(self, category):
raise NotImplementedError()

def get_vars_names(self, category):
raise NotImplementedError()

def get_vars_long_names(self, category):
raise NotImplementedError()


ANALYSIS_STATE_VALUES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
FORCING_VALUES = [10, 11, 12, 13, 14, 15, 16, 17, 18, 19]


@pytest.mark.parametrize(
"ar_steps,num_past_forcing_steps,num_future_forcing_steps",
[[3, 0, 0], [3, 1, 0], [3, 2, 0], [3, 3, 0]],
)
def test_time_slicing_analysis(
ar_steps, num_past_forcing_steps, num_future_forcing_steps
):
# state and forcing variables have only on dimension, `time`
time_values = np.datetime64("2020-01-01") + np.arange(
len(ANALYSIS_STATE_VALUES)
)
assert len(ANALYSIS_STATE_VALUES) == len(FORCING_VALUES) == len(time_values)

datastore = SinglePointDummyDatastore(
state_data=ANALYSIS_STATE_VALUES,
forcing_data=FORCING_VALUES,
time_values=time_values,
is_forecast=False,
)

dataset = WeatherDataset(
datastore=datastore,
ar_steps=ar_steps,
num_future_forcing_steps=num_past_forcing_steps,
num_past_forcing_steps=num_future_forcing_steps,
standardize=False,
)

sample = dataset[0]

init_states, target_states, forcing, _ = [
tensor.numpy() for tensor in sample
]

expected_init_states = [0, 1]
if ar_steps == 3:
expected_target_states = [2, 3, 4]
else:
raise NotImplementedError()

if num_past_forcing_steps == num_future_forcing_steps == 0:
expected_forcing_values = [[12], [13], [14]]
elif num_past_forcing_steps == 1 and num_future_forcing_steps == 0:
expected_forcing_values = [[11, 12], [12, 13], [13, 14]]
elif num_past_forcing_steps == 2 and num_future_forcing_steps == 0:
expected_forcing_values = [[10, 11, 12], [11, 12, 13], [12, 13, 14]]
elif num_past_forcing_steps == 3 and num_future_forcing_steps == 0:
raise Exception("No idea what this should be ...")
else:
raise NotImplementedError()

# init_states: (2, N_grid, d_features)
# target_states: (ar_steps, N_grid, d_features)
# forcing: (ar_steps, N_grid, d_windowed_forcing)
# target_times: (ar_steps,)
assert init_states.shape == (2, 1, 1)
assert init_states[:, 0, 0].tolist() == expected_init_states

assert target_states.shape == (3, 1, 1)
assert target_states[:, 0, 0].tolist() == expected_target_states

assert forcing.shape == (
3,
1,
1 + num_past_forcing_steps + num_future_forcing_steps,
)
np.testing.assert_equal(forcing[:, 0, :], np.array(expected_forcing_values))

0 comments on commit acb8ffa

Please sign in to comment.