Skip to content

Commit

Permalink
Add optional separate grid embedder for boundary
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Nov 13, 2024
1 parent 65347b9 commit 3b2000e
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 8 deletions.
2 changes: 1 addition & 1 deletion neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(self, args):
) = self.grid_static_features.shape
(
self.num_boundary_nodes,
boundary_static_dim, # TODO Need for computation below
boundary_static_dim, # TODO Will need for computation below
) = self.boundary_static_features.shape
self.num_input_nodes = self.num_grid_nodes + self.num_boundary_nodes
self.grid_dim = (
Expand Down
22 changes: 15 additions & 7 deletions neural_lam/models/base_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,15 +39,23 @@ def __init__(self, args):
# Define sub-models
# Feature embedders for grid
self.mlp_blueprint_end = [args.hidden_dim] * (args.hidden_layers + 1)
# TODO Optional separate embedder for boundary nodes
assert self.grid_dim == self.boundary_dim, (
"Grid and boundary input dimension must be the same when using "
f"the same encoder, got grid_dim={self.grid_dim}, "
f"boundary_dim={self.boundary_dim}"
)
self.grid_embedder = utils.make_mlp(
[self.grid_dim] + self.mlp_blueprint_end
)
# Optional separate embedder for boundary nodes
print(args.shared_grid_embedder)
if args.shared_grid_embedder:
assert self.grid_dim == self.boundary_dim, (
"Grid and boundary input dimension must be the same when using "
f"the same embedder, got grid_dim={self.grid_dim}, "
f"boundary_dim={self.boundary_dim}"
)
self.boundary_embedder = self.grid_embedder
else:
self.boundary_embedder = utils.make_mlp(
[self.boundary_dim] + self.mlp_blueprint_end
)

self.g2m_embedder = utils.make_mlp([g2m_dim] + self.mlp_blueprint_end)
self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end)

Expand Down Expand Up @@ -136,7 +144,7 @@ def predict_step(

# Embed all features
grid_emb = self.grid_embedder(grid_features) # (B, num_grid_nodes, d_h)
boundary_emb = self.grid_embedder(boundary_features)
boundary_emb = self.boundary_embedder(boundary_features)
# (B, num_boundary_nodes, d_h)
g2m_emb = self.g2m_embedder(self.g2m_features) # (M_g2m, d_h)
m2g_emb = self.m2g_embedder(self.m2g_features) # (M_m2g, d_h)
Expand Down
7 changes: 7 additions & 0 deletions neural_lam/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,13 @@ def main(input_args=None):
"output dimensions "
"(default: False (no))",
)
parser.add_argument(
"--shared_grid_embedder",
action="store_true", # Default to separate embedders
help="If the same embedder MLP should be used for interior and boundary"
" grid nodes. Note that this requires the same dimensionality for "
"both kinds of grid inputs. (default: False (no))",
)

# Training options
parser.add_argument(
Expand Down

0 comments on commit 3b2000e

Please sign in to comment.