Skip to content

Commit

Permalink
Fix graph loading and boundary mask
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Dec 2, 2024
1 parent ac9c69d commit b35072d
Show file tree
Hide file tree
Showing 6 changed files with 16 additions and 12 deletions.
10 changes: 6 additions & 4 deletions neural_lam/build_rectangular_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def main(input_args=None):
help="Path to the configuration for neural-lam",
)
parser.add_argument(
"--name",
"--graph_name",
type=str,
help="Name to save graph as (default: multiscale)",
)
Expand Down Expand Up @@ -74,8 +74,8 @@ def main(input_args=None):
args.config_path is not None
), "Specify your config with --config_path"
assert (
args.name is not None
), "Specify the name to save graph as with --name"
args.graph_name is not None
), "Specify the name to save graph as with --graph_name"

_, datastore = load_config_and_datastore(config_path=args.config_path)

Expand Down Expand Up @@ -124,7 +124,9 @@ def main(input_args=None):
print(f"{name}: {subgraph}")

# Save graph
graph_dir_path = os.path.join(datastore.root_path, "graphs", args.name)
graph_dir_path = os.path.join(
datastore.root_path, "graphs", args.graph_name
)
os.makedirs(graph_dir_path, exist_ok=True)
for component, graph in graph_comp.items():
# This seems like a bit of a hack, maybe better if saving in wmg
Expand Down
4 changes: 2 additions & 2 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ def __init__(
static_features_torch = torch.tensor(arr_static, dtype=torch.float32)
self.register_buffer(
"grid_static_features",
static_features_torch[self.boundary_mask.to(torch.bool)],
static_features_torch[self.boundary_mask[:, 0].to(torch.bool)],
persistent=False,
)
self.register_buffer(
"boundary_static_features",
static_features_torch[self.interior_mask.to(torch.bool)],
static_features_torch[self.interior_mask[:, 0].to(torch.bool)],
persistent=False,
)

Expand Down
2 changes: 1 addition & 1 deletion neural_lam/models/base_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
super().__init__(args, config=config, datastore=datastore)

# Load graph with static features
graph_dir_path = datastore.root_path / "graphs" / args.graph
graph_dir_path = datastore.root_path / "graphs" / args.graph_name
self.hierarchical, graph_ldict = utils.load_graph(
graph_dir_path=graph_dir_path
)
Expand Down
6 changes: 4 additions & 2 deletions neural_lam/plot_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def main():
help="Path to the configuration for neural-lam",
)
parser.add_argument(
"--name",
"--graph_name",
type=str,
default="multiscale",
help="Name of saved graph to plot (default: multiscale)",
Expand All @@ -46,7 +46,9 @@ def main():
_, datastore = load_config_and_datastore(config_path=args.config_path)

# Load graph data
graph_dir_path = os.path.join(datastore.root_path, "graphs", args.name)
graph_dir_path = os.path.join(
datastore.root_path, "graphs", args.graph_name
)
hierarchical, graph_ldict = utils.load_graph(graph_dir_path=graph_dir_path)
(g2m_edge_index, m2g_edge_index, m2m_edge_index,) = (
graph_ldict["g2m_edge_index"],
Expand Down
2 changes: 1 addition & 1 deletion neural_lam/train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def main(input_args=None):

# Model architecture
parser.add_argument(
"--graph",
"--graph_name",
type=str,
default="multiscale",
help="Graph to load and use in graph-based model "
Expand Down
4 changes: 2 additions & 2 deletions neural_lam/weather_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,9 @@ def __init__(

# Load border/interior mask for splitting
border_mask_float = torch.tensor(
self.datastore.boundary_mask, dtype=torch.float32
self.datastore.boundary_mask.to_numpy(), dtype=torch.float32
)
self.border_mask = border_mask_float.to(torch.bool)[:, 0]
self.border_mask = border_mask_float.to(torch.bool)
self.interior_mask = torch.logical_not(self.border_mask)

def __len__(self):
Expand Down

0 comments on commit b35072d

Please sign in to comment.