
Commit

pre_commits
sadamov committed May 28, 2024
1 parent 6685e94 commit 6423fdf
Showing 9 changed files with 69 additions and 33 deletions.
10 changes: 8 additions & 2 deletions create_mesh.py
@@ -125,7 +125,11 @@ def mk_2d_graph(xy, nx, ny):

     # add diagonal edges
     g.add_edges_from(
-        [((x, y), (x + 1, y + 1)) for x in range(nx - 1) for y in range(ny - 1)]
+        [
+            ((x, y), (x + 1, y + 1))
+            for x in range(nx - 1)
+            for y in range(ny - 1)
+        ]
         + [
             ((x + 1, y), (x, y + 1))
             for x in range(nx - 1)
@@ -343,7 +347,9 @@ def main():
                 .reshape(int(n / nx) ** 2, 2)
             )
             ij = [tuple(x) for x in ij]
-            G[lev] = networkx.relabel_nodes(G[lev], dict(zip(G[lev].nodes, ij)))
+            G[lev] = networkx.relabel_nodes(
+                G[lev], dict(zip(G[lev].nodes, ij))
+            )
             G_tot = networkx.compose(G_tot, G[lev])

         # Relabel mesh nodes to start with 0
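
For readers skimming the hunk above: the reformatted comprehension builds the two diagonal edge directions of a regular 2D mesh. A minimal, self-contained sketch of that pattern (the 3x3 grid size is an assumption for illustration, not from the commit):

```python
import networkx

# Toy 3x3 grid; nodes are (x, y) tuples, as in mk_2d_graph
nx_dim, ny_dim = 3, 3
g = networkx.Graph()
g.add_nodes_from((x, y) for x in range(nx_dim) for y in range(ny_dim))

# "/" diagonals: (x, y) -- (x + 1, y + 1)
g.add_edges_from(
    ((x, y), (x + 1, y + 1))
    for x in range(nx_dim - 1)
    for y in range(ny_dim - 1)
)
# "\" diagonals: (x + 1, y) -- (x, y + 1)
g.add_edges_from(
    ((x + 1, y), (x, y + 1))
    for x in range(nx_dim - 1)
    for y in range(ny_dim - 1)
)
print(g.number_of_edges())  # 8 diagonal edges for a 3x3 grid
```
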
25 changes: 19 additions & 6 deletions neural_lam/models/ar_model.py
@@ -38,7 +38,9 @@ def __init__(self, args):
         self.output_std = bool(args.output_std)
         if self.output_std:
             # Pred. dim. in grid cell
-            self.grid_output_dim = 2 * self.config_loader.num_data_vars("state")
+            self.grid_output_dim = 2 * self.config_loader.num_data_vars(
+                "state"
+            )
         else:
             # Pred. dim. in grid cell
             self.grid_output_dim = self.config_loader.num_data_vars("state")
@@ -87,7 +89,9 @@ def __init__(self, args):
         self.spatial_loss_maps = []

         # Load normalization statistics
-        self.normalization_stats = self.config_loader.load_normalization_stats()
+        self.normalization_stats = (
+            self.config_loader.load_normalization_stats()
+        )
         if self.normalization_stats is not None:
             for (
                 var_name,
@@ -236,7 +240,11 @@ def training_step(self, batch):

         log_dict = {"train_loss": batch_loss}
         self.log_dict(
-            log_dict, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True
+            log_dict,
+            prog_bar=True,
+            on_step=True,
+            on_epoch=True,
+            sync_dist=True,
         )
         return batch_loss

@@ -362,7 +370,8 @@ def test_step(self, batch, batch_idx):
         ):
             # Need to plot more example predictions
             n_additional_examples = min(
-                prediction.shape[0], self.n_example_pred - self.plotted_examples
+                prediction.shape[0],
+                self.n_example_pred - self.plotted_examples,
             )

             self.plot_examples(
@@ -584,10 +593,14 @@ def on_test_epoch_end(self):
                 )
                 for loss_map in mean_spatial_loss
             ]
-            pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
+            pdf_loss_maps_dir = os.path.join(
+                wandb.run.dir, "spatial_loss_maps"
+            )
             os.makedirs(pdf_loss_maps_dir, exist_ok=True)
             for t_i, fig in zip(self.args.val_steps_log, pdf_loss_map_figs):
-                fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf"))
+                fig.savefig(
+                    os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")
+                )
             # save mean spatial loss as .pt file also
             torch.save(
                 mean_spatial_loss.cpu(),
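
A hedged sketch of why grid_output_dim doubles when output_std is set: the network then predicts a mean and a raw std per state variable, which are split apart later with chunk. The variable count and the softplus choice are illustrative assumptions, not the exact model code:

```python
import torch

num_state_vars = 5  # assumed number of state variables
output_std = True
grid_output_dim = 2 * num_state_vars if output_std else num_state_vars

# Fake network output: (batch, num_grid_nodes, grid_output_dim)
net_output = torch.randn(2, 10, grid_output_dim)
if output_std:
    # First half predicts the state delta, second half a raw std
    pred_delta_mean, pred_std_raw = net_output.chunk(2, dim=-1)
    pred_std = torch.nn.functional.softplus(pred_std_raw)  # keep std positive
    print(pred_delta_mean.shape, pred_std.shape)  # both (2, 10, 5)
```
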
9 changes: 4 additions & 5 deletions neural_lam/models/base_graph_model.py
@@ -118,8 +118,8 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
             dim=-1,
         )

-        # Embed all features
-        grid_emb = self.grid_embedder(grid_features)  # (B, num_grid_nodes, d_h)
+        # Embed all features  # (B, num_grid_nodes, d_h)
+        grid_emb = self.grid_embedder(grid_features)
         g2m_emb = self.g2m_embedder(self.g2m_features)  # (M_g2m, d_h)
         m2g_emb = self.m2g_embedder(self.m2g_features)  # (M_m2g, d_h)
         mesh_emb = self.embedd_mesh_nodes()
@@ -149,9 +149,8 @@ def predict_step(self, prev_state, prev_prev_state, forcing):
         )  # (B, num_grid_nodes, d_h)

         # Map to output dimension, only for grid
-        net_output = self.output_map(
-            grid_rep
-        )  # (B, num_grid_nodes, d_grid_out)
+        # (B, num_grid_nodes, d_grid_out)
+        net_output = self.output_map(grid_rep)

         if self.output_std:
             pred_delta_mean, pred_std_raw = net_output.chunk(
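
To make the relocated shape comments concrete, a toy sketch of the embed-then-map flow (all dimensions are made-up assumptions, and plain linear layers stand in for the model's MLP embedders):

```python
import torch
import torch.nn as nn

B, num_grid_nodes, d_grid_in, d_h, d_grid_out = 2, 100, 17, 32, 5  # assumed

grid_embedder = nn.Linear(d_grid_in, d_h)  # stand-in for the MLP embedder
output_map = nn.Linear(d_h, d_grid_out)    # stand-in for the output head

grid_features = torch.randn(B, num_grid_nodes, d_grid_in)
grid_emb = grid_embedder(grid_features)  # (B, num_grid_nodes, d_h)
grid_rep = grid_emb                      # stand-in for the processed rep.
net_output = output_map(grid_rep)        # (B, num_grid_nodes, d_grid_out)
print(net_output.shape)                  # torch.Size([2, 100, 5])
```
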
4 changes: 3 additions & 1 deletion neural_lam/models/graph_lam.py
@@ -32,7 +32,9 @@ def __init__(self, args):

         # Define sub-models
         # Feature embedders for mesh
-        self.mesh_embedder = utils.make_mlp([mesh_dim] + self.mlp_blueprint_end)
+        self.mesh_embedder = utils.make_mlp(
+            [mesh_dim] + self.mlp_blueprint_end
+        )
         self.m2m_embedder = utils.make_mlp([m2m_dim] + self.mlp_blueprint_end)

         # GNNs
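
The `[mesh_dim] + self.mlp_blueprint_end` argument is a list of layer widths. A sketch of what such a blueprint-driven MLP factory might look like (make_mlp_sketch is a hypothetical stand-in; utils.make_mlp is only assumed to behave roughly like this):

```python
import torch.nn as nn

def make_mlp_sketch(blueprint):
    """Build an MLP from a list of layer widths, e.g. [in, hidden, out]."""
    layers = []
    for in_dim, out_dim in zip(blueprint[:-1], blueprint[1:]):
        layers.append(nn.Linear(in_dim, out_dim))
        layers.append(nn.SiLU())
    return nn.Sequential(*layers[:-1])  # drop the trailing activation

mesh_dim, mlp_blueprint_end = 3, [64, 64]  # assumed widths
mesh_embedder = make_mlp_sketch([mesh_dim] + mlp_blueprint_end)
print(mesh_embedder)  # Linear(3, 64) -> SiLU -> Linear(64, 64)
```
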
19 changes: 11 additions & 8 deletions neural_lam/models/hi_lam.py
@@ -101,9 +101,8 @@ def mesh_down_step(
             reversed(same_gnns[:-1]),
         ):
             # Extract representations
-            send_node_rep = mesh_rep_levels[
-                level_l + 1
-            ]  # (B, N_mesh[l+1], d_h)
+            # (B, N_mesh[l+1], d_h)
+            send_node_rep = mesh_rep_levels[level_l + 1]
             rec_node_rep = mesh_rep_levels[level_l]  # (B, N_mesh[l], d_h)
             down_edge_rep = mesh_down_rep[level_l]
             same_edge_rep = mesh_same_rep[level_l]
@@ -139,9 +138,8 @@ def mesh_up_step(
             zip(up_gnns, same_gnns[1:]), start=1
         ):
             # Extract representations
-            send_node_rep = mesh_rep_levels[
-                level_l - 1
-            ]  # (B, N_mesh[l-1], d_h)
+            # (B, N_mesh[l-1], d_h)
+            send_node_rep = mesh_rep_levels[level_l - 1]
             rec_node_rep = mesh_rep_levels[level_l]  # (B, N_mesh[l], d_h)
             up_edge_rep = mesh_up_rep[level_l - 1]
             same_edge_rep = mesh_same_rep[level_l]
@@ -183,7 +181,11 @@ def hi_processor_step(
             self.mesh_up_same_gnns,
         ):
             # Down
-            mesh_rep_levels, mesh_same_rep, mesh_down_rep = self.mesh_down_step(
+            (
+                mesh_rep_levels,
+                mesh_same_rep,
+                mesh_down_rep,
+            ) = self.mesh_down_step(
                 mesh_rep_levels,
                 mesh_same_rep,
                 mesh_down_rep,
@@ -200,5 +202,6 @@
                 up_same_gnns,
             )

-        # Note: We return all, even though only down edges really are used later
+        # Note: We return all, even though only down edges really are used
+        # later
         return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep
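
The down/up steps walk the mesh hierarchy one level at a time; a toy sketch of the top-down iteration order behind mesh_down_step (the level count is an assumption, and strings stand in for node representations):

```python
# Toy sketch: level l receives from level l + 1, walking top-down
num_levels = 3
mesh_rep_levels = [f"rep_l{level}" for level in range(num_levels)]

for level_l in reversed(range(num_levels - 1)):
    send_node_rep = mesh_rep_levels[level_l + 1]  # (B, N_mesh[l+1], d_h)
    rec_node_rep = mesh_rep_levels[level_l]       # (B, N_mesh[l], d_h)
    print(f"down edge: {send_node_rep} -> {rec_node_rep}")
# down edge: rep_l2 -> rep_l1
# down edge: rep_l1 -> rep_l0
```
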
11 changes: 7 additions & 4 deletions neural_lam/models/hi_lam_parallel.py
@@ -27,7 +27,9 @@ def __init__(self, args):
             + list(self.mesh_down_edge_index)
         )
         total_edge_index = torch.cat(total_edge_index_list, dim=1)
-        self.edge_split_sections = [ei.shape[1] for ei in total_edge_index_list]
+        self.edge_split_sections = [
+            ei.shape[1] for ei in total_edge_index_list
+        ]

         if args.processor_layers == 0:
             self.processor = lambda x, edge_attr: (x, edge_attr)
@@ -86,11 +88,12 @@ def hi_processor_step(

         mesh_same_rep = mesh_edge_rep_sections[: self.num_levels]
         mesh_up_rep = mesh_edge_rep_sections[
-            self.num_levels : self.num_levels + (self.num_levels - 1)
+            self.num_levels : self.num_levels + (self.num_levels - 1)  # noqa
         ]
         mesh_down_rep = mesh_edge_rep_sections[
-            self.num_levels + (self.num_levels - 1) :
+            self.num_levels + (self.num_levels - 1) :  # noqa
         ]  # Last are down edges

-        # Note: We return all, even though only down edges really are used later
+        # Note: We return all, even though only down edges really are used
+        # later
         return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep
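
A sketch of how edge_split_sections lets the concatenated edge representations be cut back into same/up/down groups after joint processing (torch.split with a list of section sizes; the tensor shapes here are made up):

```python
import torch

# Three edge sets with different edge counts, as in total_edge_index_list
edge_reps = [torch.randn(4, 8), torch.randn(2, 8), torch.randn(3, 8)]
edge_split_sections = [rep.shape[0] for rep in edge_reps]

total_edge_rep = torch.cat(edge_reps, dim=0)  # (9, 8)
sections = torch.split(total_edge_rep, edge_split_sections, dim=0)
print([s.shape[0] for s in sections])  # [4, 2, 3]
```
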
7 changes: 5 additions & 2 deletions neural_lam/utils.py
@@ -40,7 +40,9 @@ def load_graph(graph_name, device="cpu"):
     graph_dir_path = os.path.join("graphs", graph_name)

     def loads_file(fn):
-        return torch.load(os.path.join(graph_dir_path, fn), map_location=device)
+        return torch.load(
+            os.path.join(graph_dir_path, fn), map_location=device
+        )

     # Load edges (edge_index)
     m2m_edge_index = BufferList(
@@ -53,7 +55,8 @@ def loads_file(fn):
     hierarchical = n_levels > 1  # Not just single level mesh graph

     # Load static edge features
-    m2m_features = loads_file("m2m_features.pt")  # List of (M_m2m[l], d_edge_f)
+    # List of (M_m2m[l], d_edge_f)
+    m2m_features = loads_file("m2m_features.pt")
     g2m_features = loads_file("g2m_features.pt")  # (M_g2m, d_edge_f)
     m2g_features = loads_file("m2g_features.pt")  # (M_m2g, d_edge_f)

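
The loads_file closure pins every graph tensor load to one directory and one device. A hedged, standalone version of the pattern (the directory name is a hypothetical example):

```python
import os

import torch

graph_dir_path = os.path.join("graphs", "example_graph")  # hypothetical

def loads_file(fn, device="cpu"):
    # map_location puts the loaded tensors straight on the target device
    return torch.load(
        os.path.join(graph_dir_path, fn), map_location=device
    )

# e.g. m2m_features = loads_file("m2m_features.pt")
```
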
13 changes: 10 additions & 3 deletions neural_lam/weather_dataset.py
@@ -39,7 +39,9 @@ def __init__(

         self.state = self.config_loader.process_dataset("state", self.split)
         assert self.state is not None, "State dataset not found"
-        self.forcing = self.config_loader.process_dataset("forcing", self.split)
+        self.forcing = self.config_loader.process_dataset(
+            "forcing", self.split
+        )
         self.boundary = self.config_loader.process_dataset(
             "boundary", self.split
         )
@@ -69,7 +71,10 @@ def __init__(
method="nearest",
)
.pad(
time=(self.boundary_window // 2, self.boundary_window // 2),
time=(
self.boundary_window // 2,
self.boundary_window // 2,
),
mode="edge",
)
.rolling(time=self.boundary_window, center=True)
@@ -87,7 +92,9 @@ def __getitem__(self, idx):
         )

         forcing = (
-            self.forcing_windowed.isel(time=slice(idx + 2, idx + self.ar_steps))
+            self.forcing_windowed.isel(
+                time=slice(idx + 2, idx + self.ar_steps)
+            )
             .stack(variable_window=("variable", "window"))
             .values
             if self.forcing is not None
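
The pad-then-roll construction in the second hunk can be reproduced on a toy xarray DataArray (sizes are assumptions; the real code also interpolates the boundary onto the state grid first):

```python
import numpy as np
import xarray as xr

boundary_window = 3
da = xr.DataArray(np.arange(6.0), dims="time")

windowed = (
    da.pad(
        time=(boundary_window // 2, boundary_window // 2),
        mode="edge",  # repeat the edge values into the padding
    )
    .rolling(time=boundary_window, center=True)
    .construct("window")
)
print(windowed.shape)  # (8, 3): padded time x window
```
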
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -6,10 +6,10 @@ version = "0.1.0"
packages = ["neural_lam"]

[tool.black]
line-length = 80
line-length = 79

[tool.isort]
default_section = "THIRDPARTY"
default_section = "THIRDPARTY" #codespell:ignore
profile = "black"
# Headings
import_heading_stdlib = "Standard library"
