Skip to content

Commit

Permalink
Make graph creation and plotting work with datastores
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Dec 2, 2024
1 parent bc21b73 commit ac9c69d
Show file tree
Hide file tree
Showing 6 changed files with 57 additions and 39 deletions.
47 changes: 30 additions & 17 deletions neural_lam/build_rectangular_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import weather_model_graphs as wmg

# Local
from . import config, utils
from . import utils
from .config import load_config_and_datastore

WMG_ARCHETYPES = {
"keisler": wmg.create.archetype.create_keisler_graph,
Expand All @@ -24,10 +25,14 @@ def main(input_args=None):

# Inputs and outputs
parser.add_argument(
"--data_config",
"--config_path",
type=str,
default="neural_lam/data_config.yaml",
help="Path to data config file",
help="Path to the configuration for neural-lam",
)
parser.add_argument(
"--name",
type=str,
help="Name to save graph as (default: multiscale)",
)
parser.add_argument(
"--output_dir",
Expand Down Expand Up @@ -65,21 +70,28 @@ def main(input_args=None):
)
args = parser.parse_args(input_args)

# Load grid positions
config_loader = config.Config.from_file(args.data_config)
assert (
args.config_path is not None
), "Specify your config with --config_path"
assert (
args.name is not None
), "Specify the name to save graph as with --name"

_, datastore = load_config_and_datastore(config_path=args.config_path)

# Load grid positions
# TODO Do not get normalised positions
coords = utils.get_reordered_grid_pos(config_loader.dataset.name).numpy()
coords = utils.get_reordered_grid_pos(datastore).numpy()
# (num_nodes_full, 2)

# Construct mask
static_data = utils.load_static_data(config_loader.dataset.name)
num_full_grid = coords.shape[0]
num_boundary = datastore.boundary_mask.to_numpy().sum()
num_interior = num_full_grid - num_boundary
decode_mask = np.concatenate(
(
np.ones(static_data["grid_static_features"].shape[0], dtype=bool),
np.zeros(
static_data["boundary_static_features"].shape[0], dtype=bool
),
np.ones(num_interior, dtype=bool),
np.zeros(num_boundary, dtype=bool),
),
axis=0,
)
Expand Down Expand Up @@ -112,7 +124,8 @@ def main(input_args=None):
print(f"{name}: {subgraph}")

# Save graph
os.makedirs(args.output_dir, exist_ok=True)
graph_dir_path = os.path.join(datastore.root_path, "graphs", args.name)
os.makedirs(graph_dir_path, exist_ok=True)
for component, graph in graph_comp.items():
# This seems like a bit of a hack, maybe better if saving in wmg
# was made consistent with nl
Expand All @@ -130,7 +143,7 @@ def main(input_args=None):
name="m2m",
list_from_attribute="level",
edge_features=["len", "vdiff"],
output_directory=args.output_dir,
output_directory=graph_dir_path,
)
else:
# up and down directions
Expand All @@ -139,22 +152,22 @@ def main(input_args=None):
name=f"mesh_{direction}",
list_from_attribute="levels",
edge_features=["len", "vdiff"],
output_directory=args.output_dir,
output_directory=graph_dir_path,
)
else:
wmg.save.to_pyg(
graph=graph,
name=component,
list_from_attribute="dummy", # Note: Needed to output list
edge_features=["len", "vdiff"],
output_directory=args.output_dir,
output_directory=graph_dir_path,
)
else:
wmg.save.to_pyg(
graph=graph,
name=component,
edge_features=["len", "vdiff"],
output_directory=args.output_dir,
output_directory=graph_dir_path,
)


Expand Down
2 changes: 1 addition & 1 deletion neural_lam/datastore/mdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,7 +420,7 @@ def grid_shape_state(self):
assert da_x.ndim == da_y.ndim == 1
return CartesianGridShape(x=da_x.size, y=da_y.size)

def get_xy(self, category: str, stacked: bool) -> ndarray:
def get_xy(self, category: str, stacked: bool = True) -> ndarray:
"""Return the x, y coordinates of the dataset.
Parameters
Expand Down
2 changes: 1 addition & 1 deletion neural_lam/datastore/npyfilesmeps/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ def get_vars_long_names(self, category: str) -> List[str]:
def get_num_data_vars(self, category: str) -> int:
return len(self.get_vars_names(category=category))

def get_xy(self, category: str, stacked: bool) -> np.ndarray:
def get_xy(self, category: str, stacked: bool = True) -> np.ndarray:
"""Return the x, y coordinates of the dataset.
Parameters
Expand Down
2 changes: 1 addition & 1 deletion neural_lam/models/base_graph_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def __init__(self, args, config: NeuralLAMConfig, datastore: BaseDatastore):
super().__init__(args, config=config, datastore=datastore)

# Load graph with static features
graph_dir_path = datastore.root_path / "graph" / args.graph
graph_dir_path = datastore.root_path / "graphs" / args.graph
self.hierarchical, graph_ldict = utils.load_graph(
graph_dir_path=graph_dir_path
)
Expand Down
41 changes: 23 additions & 18 deletions neural_lam/plot_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,25 +11,20 @@
from . import utils
from .config import load_config_and_datastore

MESH_HEIGHT = 0.1
MESH_LEVEL_DIST = 0.2
GRID_HEIGHT = 0


def main():
"""Plot graph structure in 3D using plotly."""
parser = ArgumentParser(description="Plot graph")
parser.add_argument(
"--datastore_config_path",
"--config_path",
type=str,
default="tests/datastore_examples/mdp/config.yaml",
help="Path for the datastore config",
help="Path to the configuration for neural-lam",
)
parser.add_argument(
"--graph",
"--name",
type=str,
default="multiscale",
help="Graph to plot (default: multiscale)",
help="Name of saved graph to plot (default: multiscale)",
)
parser.add_argument(
"--save",
Expand All @@ -43,12 +38,15 @@ def main():
)

args = parser.parse_args()
_, datastore = load_config_and_datastore(
config_path=args.datastore_config_path
)

assert (
args.config_path is not None
), "Specify your config with --config_path"

_, datastore = load_config_and_datastore(config_path=args.config_path)

# Load graph data
graph_dir_path = os.path.join(datastore.root_path, "graph", args.graph)
graph_dir_path = os.path.join(datastore.root_path, "graphs", args.name)
hierarchical, graph_ldict = utils.load_graph(graph_dir_path=graph_dir_path)
(g2m_edge_index, m2g_edge_index, m2m_edge_index,) = (
graph_ldict["g2m_edge_index"],
Expand All @@ -63,12 +61,18 @@ def main():

# Extract values needed, turn to numpy
grid_pos = utils.get_reordered_grid_pos(datastore).numpy()
# Add in z-dimension
z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],))
grid_scale = np.ptp(grid_pos)

# Add in z-dimension for grid
z_grid = np.zeros((grid_pos.shape[0],)) # Grid sits at z=0
grid_pos = np.concatenate(
(grid_pos, np.expand_dims(z_grid, axis=1)), axis=1
)

# Compute z-coordinate height of mesh nodes
mesh_base_height = 0.05 * grid_scale
mesh_level_height_diff = 0.1 * grid_scale

# List of edges to plot, (edge_index, from_pos, to_pos, color,
# line_width, label)
edge_plot_list = []
Expand All @@ -79,8 +83,8 @@ def main():
np.concatenate(
(
level_static_features.numpy(),
MESH_HEIGHT
+ MESH_LEVEL_DIST
mesh_base_height
+ mesh_level_height_diff
* height_level
* np.ones((level_static_features.shape[0], 1)),
),
Expand Down Expand Up @@ -170,7 +174,8 @@ def main():
mesh_pos = mesh_static_features.numpy()

mesh_degrees = pyg.utils.degree(m2m_edge_index[1]).numpy()
z_mesh = MESH_HEIGHT + 0.01 * mesh_degrees
# 1% higher per neighbor
z_mesh = (1 + 0.01 * mesh_degrees) * mesh_base_height
mesh_node_size = mesh_degrees / 2

mesh_pos = np.concatenate(
Expand Down
2 changes: 1 addition & 1 deletion neural_lam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,7 @@ def get_reordered_grid_pos(datastore):
"""
Interior nodes first, then boundary
"""
xy_np = datastore.get_xy() # np, (num_grid, 2)
xy_np = datastore.get_xy("state") # np, (num_grid, 2)
xy_torch = torch.tensor(xy_np, dtype=torch.float32)

da_boundary_mask = datastore.boundary_mask
Expand Down

0 comments on commit ac9c69d

Please sign in to comment.