diff --git a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py index 6a74abb..fb27467 100644 --- a/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py +++ b/neural_lam/datastore/npyfilesmeps/compute_standardization_stats.py @@ -167,8 +167,8 @@ def main( ) torch.cuda.set_device(device) if torch.cuda.is_available() else None - # Setting this to the original value of the original Oskarsson et al. paper - # (2023) -> 65 forecast steps - 2 initial steps = 63 + # Setting this to the original value of the Oskarsson et al. paper (2023) + # 65 forecast steps - 2 initial steps = 63 ar_steps = 63 ds = WeatherDataset( datastore=datastore, diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index ef8a010..a2bc59f 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -173,10 +173,10 @@ def _slice_state_time(self, da_state, idx, n_steps: int): # The current implementation requires at least 2 time steps for the # initial state (see GraphCast). init_steps = 2 - start_idx = idx + max(0, self.num_past_forcing_steps - init_steps) - end_idx = idx + max(init_steps, self.num_past_forcing_steps) + n_steps # slice the dataarray to include the required number of time steps if self.datastore.is_forecast: + start_idx = max(0, self.num_past_forcing_steps - init_steps) + end_idx = max(init_steps, self.num_past_forcing_steps) + n_steps # this implies that the data will have both `analysis_time` and # `elapsed_forecast_duration` dimensions for forecasts. We for now # simply select a analysis time and the first `n_steps` forecast @@ -198,6 +198,10 @@ def _slice_state_time(self, da_state, idx, n_steps: int): # For analysis data we slice the time dimension directly. The offset # is only relevant for the very first (and last) samples in the # dataset. + start_idx = idx + max(0, self.num_past_forcing_steps - init_steps) + end_idx = ( + idx + max(init_steps, self.num_past_forcing_steps) + n_steps + ) da_sliced = da_state.isel(time=slice(start_idx, end_idx)) return da_sliced @@ -235,7 +239,6 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int): # as past forcings. init_steps = 2 da_list = [] - offset = idx + max(init_steps, self.num_past_forcing_steps) if self.datastore.is_forecast: # This implies that the data will have both `analysis_time` and @@ -244,6 +247,7 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int): # times (given no offset). Note that this means that we get one # sample per forecast. # Add a 'time' dimension using the actual forecast times + offset = max(init_steps, self.num_past_forcing_steps) for step in range(n_steps): start_idx = offset + step - self.num_past_forcing_steps end_idx = offset + step + self.num_future_forcing_steps @@ -280,6 +284,7 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int): # For analysis data, we slice the time dimension directly. The # offset is only relevant for the very first (and last) samples in # the dataset. + offset = idx + max(init_steps, self.num_past_forcing_steps) for step in range(n_steps): start_idx = offset + step - self.num_past_forcing_steps end_idx = offset + step + self.num_future_forcing_steps