Skip to content

Commit

Permalink
Bugfix, idx removed from forecast forcing window indices
Browse files Browse the repository at this point in the history
  • Loading branch information
sadamov committed Nov 13, 2024
1 parent 38cdfe6 commit cd53b21
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -167,8 +167,8 @@ def main(
)
torch.cuda.set_device(device) if torch.cuda.is_available() else None

# Setting this to the original value of the original Oskarsson et al. paper
# (2023) -> 65 forecast steps - 2 initial steps = 63
# Setting this to the original value of the Oskarsson et al. paper (2023)
# 65 forecast steps - 2 initial steps = 63
ar_steps = 63
ds = WeatherDataset(
datastore=datastore,
Expand Down
11 changes: 8 additions & 3 deletions neural_lam/weather_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,10 @@ def _slice_state_time(self, da_state, idx, n_steps: int):
# The current implementation requires at least 2 time steps for the
# initial state (see GraphCast).
init_steps = 2
start_idx = idx + max(0, self.num_past_forcing_steps - init_steps)
end_idx = idx + max(init_steps, self.num_past_forcing_steps) + n_steps
# slice the dataarray to include the required number of time steps
if self.datastore.is_forecast:
start_idx = max(0, self.num_past_forcing_steps - init_steps)
end_idx = max(init_steps, self.num_past_forcing_steps) + n_steps
# this implies that the data will have both `analysis_time` and
# `elapsed_forecast_duration` dimensions for forecasts. We for now
# simply select a analysis time and the first `n_steps` forecast
Expand All @@ -198,6 +198,10 @@ def _slice_state_time(self, da_state, idx, n_steps: int):
# For analysis data we slice the time dimension directly. The offset
# is only relevant for the very first (and last) samples in the
# dataset.
start_idx = idx + max(0, self.num_past_forcing_steps - init_steps)
end_idx = (
idx + max(init_steps, self.num_past_forcing_steps) + n_steps
)
da_sliced = da_state.isel(time=slice(start_idx, end_idx))
return da_sliced

Expand Down Expand Up @@ -235,7 +239,6 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int):
# as past forcings.
init_steps = 2
da_list = []
offset = idx + max(init_steps, self.num_past_forcing_steps)

if self.datastore.is_forecast:
# This implies that the data will have both `analysis_time` and
Expand All @@ -244,6 +247,7 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int):
# times (given no offset). Note that this means that we get one
# sample per forecast.
# Add a 'time' dimension using the actual forecast times
offset = max(init_steps, self.num_past_forcing_steps)
for step in range(n_steps):
start_idx = offset + step - self.num_past_forcing_steps
end_idx = offset + step + self.num_future_forcing_steps
Expand Down Expand Up @@ -280,6 +284,7 @@ def _slice_forcing_time(self, da_forcing, idx, n_steps: int):
# For analysis data, we slice the time dimension directly. The
# offset is only relevant for the very first (and last) samples in
# the dataset.
offset = idx + max(init_steps, self.num_past_forcing_steps)
for step in range(n_steps):
start_idx = offset + step - self.num_past_forcing_steps
end_idx = offset + step + self.num_future_forcing_steps
Expand Down

0 comments on commit cd53b21

Please sign in to comment.