Add flag making boundary forcing optional in models
joeloskarsson committed Dec 3, 2024
1 parent 9edfec3 commit 6e1c53c
Showing 2 changed files with 59 additions and 34 deletions.
27 changes: 19 additions & 8 deletions neural_lam/models/ar_model.py
@@ -48,6 +48,10 @@ def __init__(
         num_past_forcing_steps = args.num_past_forcing_steps
         num_future_forcing_steps = args.num_future_forcing_steps
 
+        # TODO: Set based on existence of boundary forcing datastore
+        # TODO: Adjust what is stored here based on self.boundary_forced
+        self.boundary_forced = False
+
         # Set up boundary mask
         boundary_mask = torch.tensor(
             da_boundary_mask.values, dtype=torch.float32
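
The two TODOs above are placeholders. A minimal sketch of how the first one might resolve, assuming a hypothetical optional `datastore_boundary` argument (not part of this commit):

    # Hypothetical sketch: infer the flag from an optional boundary
    # datastore, as the TODO suggests. The argument name is assumed.
    def resolve_boundary_forced(datastore_boundary=None):
        return datastore_boundary is not None

    print(resolve_boundary_forced())          # False: interior-only model
    print(resolve_boundary_forced(object()))  # True: boundary forcing active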
@@ -125,12 +129,6 @@ def __init__(
             self.num_grid_nodes,
             grid_static_dim,
         ) = self.grid_static_features.shape
-
-        (
-            self.num_boundary_nodes,
-            boundary_static_dim,  # TODO Will need for computation below
-        ) = self.boundary_static_features.shape
-        self.num_input_nodes = self.num_grid_nodes + self.num_boundary_nodes
         self.grid_dim = (
             2 * self.grid_output_dim
             + grid_static_dim
@@ -139,7 +137,16 @@
             * num_forcing_vars
             * (num_past_forcing_steps + num_future_forcing_steps + 1)
         )
-        self.boundary_dim = self.grid_dim  # TODO Compute separately
+        if self.boundary_forced:
+            self.boundary_dim = self.grid_dim  # TODO Compute separately
+            (
+                self.num_boundary_nodes,
+                boundary_static_dim,  # TODO Will need for computation below
+            ) = self.boundary_static_features.shape
+            self.num_input_nodes = self.num_grid_nodes + self.num_boundary_nodes
+        else:
+            # Only interior grid nodes
+            self.num_input_nodes = self.num_grid_nodes
 
         # Instantiate loss function
         self.loss = metrics.get_metric(args.loss)
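
The branch above also moves the node-count bookkeeping under the new flag. A toy illustration of the shape unpacking and the resulting `num_input_nodes`, with assumed tensor shapes (not values from the repository):

    import torch

    # Assumed shapes: 1000 interior nodes with 4 static fields each,
    # 200 boundary nodes. Real values come from the datastore.
    grid_static_features = torch.zeros(1000, 4)
    boundary_static_features = torch.zeros(200, 4)

    num_grid_nodes, grid_static_dim = grid_static_features.shape

    boundary_forced = True
    if boundary_forced:
        num_boundary_nodes, _ = boundary_static_features.shape
        num_input_nodes = num_grid_nodes + num_boundary_nodes  # 1200
    else:
        num_input_nodes = num_grid_nodes  # interior only: 1000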
@@ -241,7 +248,11 @@ def unroll_prediction(self, init_states, forcing, boundary_forcing):
 
         for i in range(pred_steps):
             forcing_step = forcing[:, i]
-            boundary_forcing_step = boundary_forcing[:, i]
+
+            if self.boundary_forced:
+                boundary_forcing_step = boundary_forcing[:, i]
+            else:
+                boundary_forcing_step = None
 
             pred_state, pred_std = self.predict_step(
                 prev_state, prev_prev_state, forcing_step, boundary_forcing_step
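
A condensed, runnable sketch of the calling pattern this hunk introduces: when the flag is off, the per-step boundary forcing is simply `None`, and `predict_step` must tolerate that. The stub below stands in for the real forward pass, with assumed shapes throughout:

    import torch

    def predict_step(prev_state, prev_prev_state, forcing, boundary_forcing):
        # Stand-in for the model forward; ignores forcing for brevity.
        return (prev_state + prev_prev_state) / 2, None

    boundary_forced = False  # assumed: no boundary datastore configured
    pred_steps = 3
    forcing = torch.zeros(1, pred_steps, 10, 3)  # (B, steps, nodes, f_dim)
    boundary_forcing = None  # whole tensor absent when the flag is off
    prev_prev_state = torch.zeros(1, 10, 5)
    prev_state = torch.ones(1, 10, 5)

    for i in range(pred_steps):
        forcing_step = forcing[:, i]
        boundary_forcing_step = (
            boundary_forcing[:, i] if boundary_forced else None
        )
        pred_state, _ = predict_step(
            prev_state, prev_prev_state, forcing_step, boundary_forcing_step
        )
        prev_prev_state, prev_state = prev_state, pred_state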
66 changes: 40 additions & 26 deletions neural_lam/models/base_graph_model.py
@@ -47,18 +47,22 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
         self.grid_embedder = utils.make_mlp(
             [self.grid_dim] + self.mlp_blueprint_end
         )
-        # Optional separate embedder for boundary nodes
-        if args.shared_grid_embedder:
-            assert self.grid_dim == self.boundary_dim, (
-                "Grid and boundary input dimension must be the same when using "
-                f"the same embedder, got grid_dim={self.grid_dim}, "
-                f"boundary_dim={self.boundary_dim}"
-            )
-            self.boundary_embedder = self.grid_embedder
-        else:
-            self.boundary_embedder = utils.make_mlp(
-                [self.boundary_dim] + self.mlp_blueprint_end
-            )
+
+        if self.boundary_forced:
+            # Define embedder for boundary nodes
+            # Optional separate embedder for boundary nodes
+            if args.shared_grid_embedder:
+                assert self.grid_dim == self.boundary_dim, (
+                    "Grid and boundary input dimension must "
+                    "be the same when using "
+                    f"the same embedder, got grid_dim={self.grid_dim}, "
+                    f"boundary_dim={self.boundary_dim}"
+                )
+                self.boundary_embedder = self.grid_embedder
+            else:
+                self.boundary_embedder = utils.make_mlp(
+                    [self.boundary_dim] + self.mlp_blueprint_end
+                )
 
         self.g2m_embedder = utils.make_mlp([g2m_dim] + self.mlp_blueprint_end)
         self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end)
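
The shared-embedder option reuses one MLP for both node sets, which only works when the input dimensions match. An illustration with hypothetical dimensions, using `nn.Linear` in place of `utils.make_mlp` for brevity:

    import torch.nn as nn

    grid_dim, boundary_dim, d_h = 23, 23, 64  # assumed dimensions
    shared_grid_embedder = True

    grid_embedder = nn.Linear(grid_dim, d_h)
    if shared_grid_embedder:
        # Sharing requires identical input dims, as the assert above enforces
        assert grid_dim == boundary_dim
        boundary_embedder = grid_embedder  # same module, shared weights
    else:
        boundary_embedder = nn.Linear(boundary_dim, d_h)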
@@ -136,27 +140,37 @@ def predict_step(
             ),
             dim=-1,
         )
-        # Create full boundary node features of shape
-        # (B, num_boundary_nodes, boundary_dim)
-        boundary_features = torch.cat(
-            (
-                boundary_forcing,
-                self.expand_to_batch(self.boundary_static_features, batch_size),
-            ),
-            dim=-1,
-        )
+
+        if self.boundary_forced:
+            # Create full boundary node features of shape
+            # (B, num_boundary_nodes, boundary_dim)
+            boundary_features = torch.cat(
+                (
+                    boundary_forcing,
+                    self.expand_to_batch(
+                        self.boundary_static_features, batch_size
+                    ),
+                ),
+                dim=-1,
+            )
+
+            # Embed boundary features
+            boundary_emb = self.boundary_embedder(boundary_features)
+            # (B, num_boundary_nodes, d_h)
 
         # Embed all features
         grid_emb = self.grid_embedder(grid_features)  # (B, num_grid_nodes, d_h)
-        boundary_emb = self.boundary_embedder(boundary_features)
-        # (B, num_boundary_nodes, d_h)
         g2m_emb = self.g2m_embedder(self.g2m_features)  # (M_g2m, d_h)
         m2g_emb = self.m2g_embedder(self.m2g_features)  # (M_m2g, d_h)
         mesh_emb = self.embedd_mesh_nodes()
 
-        # Merge interior and boundary emb into input embedding
-        # We enforce ordering (interior, boundary) of nodes
-        input_emb = torch.cat((grid_emb, boundary_emb), dim=1)
+        if self.boundary_forced:
+            # Merge interior and boundary emb into input embedding
+            # We enforce ordering (interior, boundary) of nodes
+            input_emb = torch.cat((grid_emb, boundary_emb), dim=1)
+        else:
+            # Only maps from interior to mesh
+            input_emb = grid_emb
 
         # Map from grid to mesh
         mesh_emb_expanded = self.expand_to_batch(
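
A toy demonstration of the (interior, boundary) node ordering that the concatenation above enforces, with assumed shapes:

    import torch

    B, d_h = 2, 64
    grid_emb = torch.zeros(B, 1000, d_h)    # interior node embeddings first
    boundary_emb = torch.ones(B, 200, d_h)  # boundary embeddings appended
    input_emb = torch.cat((grid_emb, boundary_emb), dim=1)
    print(input_emb.shape)  # torch.Size([2, 1200, 64])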
