diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 426ee56..f48da66 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -278,7 +278,13 @@ def common_step(self, batch): (B, pred_steps, num_boundary_nodes, d_boundary_forcing), where index 0 corresponds to index 1 of init_states """ - (init_states, target_states, forcing_features, batch_times) = batch + ( + init_states, + target_states, + forcing, + boundary_forcing, + batch_times, + ) = batch prediction, pred_std = self.unroll_prediction( init_states, forcing, boundary_forcing diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 6241c1c..8e43fa4 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -295,13 +295,11 @@ def get_reordered_grid_pos(datastore): """ Interior nodes first, then boundary """ - xy_np = datastore.get_xy() # np, (num_grid, 2) + xy_np = datastore.get_xy() # np, (num_grid, 2) xy_torch = torch.tensor(xy_np, dtype=torch.float32) da_boundary_mask = datastore.boundary_mask - boundary_mask = torch.tensor( - da_boundary_mask.values, dtype=torch.bool - ) + boundary_mask = torch.tensor(da_boundary_mask.values, dtype=torch.bool) interior_mask = torch.logical_not(boundary_mask) return torch.cat( diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 455dd26..0e9cccd 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -129,7 +129,7 @@ def plot_prediction( # Plot pred and target for ax, da in zip(axes, (da_target, da_prediction)): - im = plot_on_axis( + plot_on_axis( ax, da, datastore, @@ -181,7 +181,6 @@ def plot_spatial_error( error_grid, origin="lower", extent=extent, - alpha=pixel_alpha, vmin=vmin, vmax=vmax, cmap="OrRd", diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 9dc8ca9..27be3e7 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -116,11 +116,11 @@ def __init__( # Load border/interior mask for splitting border_mask_float = torch.tensor( - self.datastore.boundary_mask, dtype=torch.float32) + self.datastore.boundary_mask, dtype=torch.float32 + ) self.border_mask = border_mask_float.to(torch.bool)[:, 0] self.interior_mask = torch.logical_not(self.border_mask) - def __len__(self): if self.datastore.is_forecast: # for now we simply create a single sample for each analysis time @@ -509,8 +509,9 @@ def __getitem__(self, idx): # To match current setup, allowing for same grid encoder also for # boundary, boundary forcing should contain (in order) prev_state, # prev_prev_state, forcing (and later added on static features). - boundary_forcing_sample = torch.cat(( - init_states_boundary, targets_boundary), dim=0) + boundary_forcing_sample = torch.cat( + (init_states_boundary, targets_boundary), dim=0 + ) boundary_forcing = torch.cat( ( @@ -532,7 +533,7 @@ def __getitem__(self, idx): target_states, forcing, boundary_forcing, - target_times + target_times, ) def __iter__(self):