Skip to content

Commit

Permalink
Linting and bugfixes
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Nov 28, 2024
1 parent 24bb665 commit bc21b73
Showing 4 changed files with 16 additions and 12 deletions.
8 changes: 7 additions & 1 deletion neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
@@ -278,7 +278,13 @@ def common_step(self, batch):
(B, pred_steps, num_boundary_nodes, d_boundary_forcing),
where index 0 corresponds to index 1 of init_states
"""
(init_states, target_states, forcing_features, batch_times) = batch
(
init_states,
target_states,
forcing,
boundary_forcing,
batch_times,
) = batch

prediction, pred_std = self.unroll_prediction(
init_states, forcing, boundary_forcing
6 changes: 2 additions & 4 deletions neural_lam/utils.py
Original file line number Diff line number Diff line change
@@ -295,13 +295,11 @@ def get_reordered_grid_pos(datastore):
"""
Interior nodes first, then boundary
"""
xy_np = datastore.get_xy() # np, (num_grid, 2)
xy_np = datastore.get_xy() # np, (num_grid, 2)
xy_torch = torch.tensor(xy_np, dtype=torch.float32)

da_boundary_mask = datastore.boundary_mask
boundary_mask = torch.tensor(
da_boundary_mask.values, dtype=torch.bool
)
boundary_mask = torch.tensor(da_boundary_mask.values, dtype=torch.bool)
interior_mask = torch.logical_not(boundary_mask)

return torch.cat(
3 changes: 1 addition & 2 deletions neural_lam/vis.py
Original file line number Diff line number Diff line change
@@ -129,7 +129,7 @@ def plot_prediction(

# Plot pred and target
for ax, da in zip(axes, (da_target, da_prediction)):
im = plot_on_axis(
plot_on_axis(
ax,
da,
datastore,
@@ -181,7 +181,6 @@ def plot_spatial_error(
error_grid,
origin="lower",
extent=extent,
alpha=pixel_alpha,
vmin=vmin,
vmax=vmax,
cmap="OrRd",
11 changes: 6 additions & 5 deletions neural_lam/weather_dataset.py
Original file line number Diff line number Diff line change
@@ -116,11 +116,11 @@ def __init__(

# Load border/interior mask for splitting
border_mask_float = torch.tensor(
self.datastore.boundary_mask, dtype=torch.float32)
self.datastore.boundary_mask, dtype=torch.float32
)
self.border_mask = border_mask_float.to(torch.bool)[:, 0]
self.interior_mask = torch.logical_not(self.border_mask)


def __len__(self):
if self.datastore.is_forecast:
# for now we simply create a single sample for each analysis time
@@ -509,8 +509,9 @@ def __getitem__(self, idx):
# To match current setup, allowing for same grid encoder also for
# boundary, boundary forcing should contain (in order) prev_state,
# prev_prev_state, forcing (and later added on static features).
boundary_forcing_sample = torch.cat((
init_states_boundary, targets_boundary), dim=0)
boundary_forcing_sample = torch.cat(
(init_states_boundary, targets_boundary), dim=0
)

boundary_forcing = torch.cat(
(
@@ -532,7 +533,7 @@ def __getitem__(self, idx):
target_states,
forcing,
boundary_forcing,
target_times
target_times,
)

def __iter__(self):

0 comments on commit bc21b73

Please sign in to comment.