From 6e1c53ca70678c559d4a37324221105beb799cea Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Tue, 3 Dec 2024 11:35:16 +0100 Subject: [PATCH] Add flag making boundary forcing optional in models --- neural_lam/models/ar_model.py | 27 +++++++---- neural_lam/models/base_graph_model.py | 66 ++++++++++++++++----------- 2 files changed, 59 insertions(+), 34 deletions(-) diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index ceadb85..ef76611 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -48,6 +48,10 @@ def __init__( num_past_forcing_steps = args.num_past_forcing_steps num_future_forcing_steps = args.num_future_forcing_steps + # TODO: Set based on existing of boundary forcing datastore + # TODO: Adjust what is stored here based on self.boundary_forced + self.boundary_forced = False + # Set up boundary mask boundary_mask = torch.tensor( da_boundary_mask.values, dtype=torch.float32 @@ -125,12 +129,6 @@ def __init__( self.num_grid_nodes, grid_static_dim, ) = self.grid_static_features.shape - - ( - self.num_boundary_nodes, - boundary_static_dim, # TODO Will need for computation below - ) = self.boundary_static_features.shape - self.num_input_nodes = self.num_grid_nodes + self.num_boundary_nodes self.grid_dim = ( 2 * self.grid_output_dim + grid_static_dim @@ -139,7 +137,16 @@ def __init__( * num_forcing_vars * (num_past_forcing_steps + num_future_forcing_steps + 1) ) - self.boundary_dim = self.grid_dim # TODO Compute separately + if self.boundary_forced: + self.boundary_dim = self.grid_dim # TODO Compute separately + ( + self.num_boundary_nodes, + boundary_static_dim, # TODO Will need for computation below + ) = self.boundary_static_features.shape + self.num_input_nodes = self.num_grid_nodes + self.num_boundary_nodes + else: + # Only interior grid nodes + self.num_input_nodes = self.num_grid_nodes # Instantiate loss function self.loss = metrics.get_metric(args.loss) @@ -241,7 +248,11 @@ def unroll_prediction(self, init_states, forcing, boundary_forcing): for i in range(pred_steps): forcing_step = forcing[:, i] - boundary_forcing_step = boundary_forcing[:, i] + + if self.boundary_forced: + boundary_forcing_step = boundary_forcing[:, i] + else: + boundary_forcing_step = None pred_state, pred_std = self.predict_step( prev_state, prev_prev_state, forcing_step, boundary_forcing_step diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 52f2d7a..61c1a68 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -47,18 +47,22 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore): self.grid_embedder = utils.make_mlp( [self.grid_dim] + self.mlp_blueprint_end ) - # Optional separate embedder for boundary nodes - if args.shared_grid_embedder: - assert self.grid_dim == self.boundary_dim, ( - "Grid and boundary input dimension must be the same when using " - f"the same embedder, got grid_dim={self.grid_dim}, " - f"boundary_dim={self.boundary_dim}" - ) - self.boundary_embedder = self.grid_embedder - else: - self.boundary_embedder = utils.make_mlp( - [self.boundary_dim] + self.mlp_blueprint_end - ) + + if self.boundary_forced: + # Define embedder for boundary nodes + # Optional separate embedder for boundary nodes + if args.shared_grid_embedder: + assert self.grid_dim == self.boundary_dim, ( + "Grid and boundary input dimension must " + "be the same when using " + f"the same embedder, got grid_dim={self.grid_dim}, " + f"boundary_dim={self.boundary_dim}" + ) + self.boundary_embedder = self.grid_embedder + else: + self.boundary_embedder = utils.make_mlp( + [self.boundary_dim] + self.mlp_blueprint_end + ) self.g2m_embedder = utils.make_mlp([g2m_dim] + self.mlp_blueprint_end) self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end) @@ -136,27 +140,37 @@ def predict_step( ), dim=-1, ) - # Create full boundary node features of shape - # (B, num_boundary_nodes, boundary_dim) - boundary_features = torch.cat( - ( - boundary_forcing, - self.expand_to_batch(self.boundary_static_features, batch_size), - ), - dim=-1, - ) + + if self.boundary_forced: + # Create full boundary node features of shape + # (B, num_boundary_nodes, boundary_dim) + boundary_features = torch.cat( + ( + boundary_forcing, + self.expand_to_batch( + self.boundary_static_features, batch_size + ), + ), + dim=-1, + ) + + # Embed boundary features + boundary_emb = self.boundary_embedder(boundary_features) + # (B, num_boundary_nodes, d_h) # Embed all features grid_emb = self.grid_embedder(grid_features) # (B, num_grid_nodes, d_h) - boundary_emb = self.boundary_embedder(boundary_features) - # (B, num_boundary_nodes, d_h) g2m_emb = self.g2m_embedder(self.g2m_features) # (M_g2m, d_h) m2g_emb = self.m2g_embedder(self.m2g_features) # (M_m2g, d_h) mesh_emb = self.embedd_mesh_nodes() - # Merge interior and boundary emb into input embedding - # We enforce ordering (interior, boundary) of nodes - input_emb = torch.cat((grid_emb, boundary_emb), dim=1) + if self.boundary_forced: + # Merge interior and boundary emb into input embedding + # We enforce ordering (interior, boundary) of nodes + input_emb = torch.cat((grid_emb, boundary_emb), dim=1) + else: + # Only maps from interior to mesh + input_emb = grid_emb # Map from grid to mesh mesh_emb_expanded = self.expand_to_batch(