Skip to content

Commit

Permalink
linting fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Leif Denby committed Aug 14, 2024
1 parent eac6e35 commit 799d55e
Show file tree
Hide file tree
Showing 30 changed files with 625 additions and 541 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ repos:
rev: v1.7.5
hooks:
- id: docformatter
args: [--in-place, --recursive]
args: [--in-place, --recursive, --config, ./pyproject.toml]
additional_dependencies: [tomli]
62 changes: 16 additions & 46 deletions neural_lam/create_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,7 @@ def plot_graph(graph, title=None):
# TODO: indicate direction of directed edges

# Move all to cpu and numpy, compute (in)-degrees
degrees = (
pyg.utils.degree(edge_index[1], num_nodes=pos.shape[0]).cpu().numpy()
)
degrees = pyg.utils.degree(edge_index[1], num_nodes=pos.shape[0]).cpu().numpy()
edge_index = edge_index.cpu().numpy()
pos = pos.cpu().numpy()

Expand Down Expand Up @@ -82,9 +80,7 @@ def sort_nodes_internally(nx_graph):


def save_edges(graph, name, base_path):
torch.save(
graph.edge_index, os.path.join(base_path, f"{name}_edge_index.pt")
)
torch.save(graph.edge_index, os.path.join(base_path, f"{name}_edge_index.pt"))
edge_features = torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to(
torch.float32
) # Save as float32
Expand All @@ -97,9 +93,7 @@ def save_edges_list(graphs, name, base_path):
os.path.join(base_path, f"{name}_edge_index.pt"),
)
edge_features = [
torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to(
torch.float32
)
torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to(torch.float32)
for graph in graphs
] # Save as float32
torch.save(edge_features, os.path.join(base_path, f"{name}_features.pt"))
Expand Down Expand Up @@ -130,11 +124,7 @@ def mk_2d_graph(xy, nx, ny):
# add diagonal edges
g.add_edges_from(
[((x, y), (x + 1, y + 1)) for x in range(nx - 1) for y in range(ny - 1)]
+ [
((x + 1, y), (x, y + 1))
for x in range(nx - 1)
for y in range(ny - 1)
]
+ [((x + 1, y), (x, y + 1)) for x in range(nx - 1) for y in range(ny - 1)]
)

# turn into directed graph
Expand Down Expand Up @@ -164,8 +154,7 @@ def create_graph(
hierarchical: bool,
create_plot: bool,
):
"""Create graph components from `xy` grid coordinates and store in
`graph_dir_path`.
"""Create graph components from `xy` grid coordinates and store in `graph_dir_path`.
Creates the following files for all graphs:
- g2m_edge_index.pt [2, N_g2m_edges]
Expand Down Expand Up @@ -225,6 +214,7 @@ def create_graph(
Returns
-------
None
"""
os.makedirs(graph_dir_path, exist_ok=True)

Expand Down Expand Up @@ -262,10 +252,7 @@ def create_graph(

if hierarchical:
# Relabel nodes of each level with level index first
G = [
prepend_node_index(graph, level_i)
for level_i, graph in enumerate(G)
]
G = [prepend_node_index(graph, level_i) for level_i, graph in enumerate(G)]

num_nodes_level = np.array([len(g_level.nodes) for g_level in G])
# First node index in each level in the hierarchical graph
Expand Down Expand Up @@ -307,9 +294,7 @@ def create_graph(
# add edge from mesh to grid
G_down.add_edge(u, v)
d = np.sqrt(
np.sum(
(G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2
)
np.sum((G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2)
)
G_down.edges[u, v]["len"] = d
G_down.edges[u, v]["vdiff"] = (
Expand All @@ -334,14 +319,10 @@ def create_graph(
down_graphs.append(pyg_down)

if create_plot:
plot_graph(
pyg_down, title=f"Down graph, {from_level} -> {to_level}"
)
plot_graph(pyg_down, title=f"Down graph, {from_level} -> {to_level}")
plt.show()

plot_graph(
pyg_down, title=f"Up graph, {to_level} -> {from_level}"
)
plot_graph(pyg_down, title=f"Up graph, {to_level} -> {from_level}")
plt.show()

# Save up and down edges
Expand Down Expand Up @@ -426,9 +407,7 @@ def create_graph(
vm = G_bottom_mesh.nodes
vm_xy = np.array([xy for _, xy in vm.data("pos")])
# distance between mesh nodes
dm = np.sqrt(
np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2)
)
dm = np.sqrt(np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2))

# grid nodes
Ny, Nx = xy.shape[1:]
Expand Down Expand Up @@ -470,13 +449,9 @@ def create_graph(
u = vg_list[i]
# add edge from grid to mesh
G_g2m.add_edge(u, v)
d = np.sqrt(
np.sum((G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]) ** 2)
)
d = np.sqrt(np.sum((G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]) ** 2))
G_g2m.edges[u, v]["len"] = d
G_g2m.edges[u, v]["vdiff"] = (
G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]
)
G_g2m.edges[u, v]["vdiff"] = G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]

pyg_g2m = from_networkx(G_g2m)

Expand Down Expand Up @@ -505,13 +480,9 @@ def create_graph(
u = vm_list[i]
# add edge from mesh to grid
G_m2g.add_edge(u, v)
d = np.sqrt(
np.sum((G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]) ** 2)
)
d = np.sqrt(np.sum((G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]) ** 2))
G_m2g.edges[u, v]["len"] = d
G_m2g.edges[u, v]["vdiff"] = (
G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]
)
G_m2g.edges[u, v]["vdiff"] = G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]

# relabel nodes to integers (sorted)
G_m2g_int = networkx.convert_node_labels_to_integers(
Expand Down Expand Up @@ -578,8 +549,7 @@ def cli(input_args=None):
"--plot",
type=int,
default=0,
help="If graphs should be plotted during generation "
"(default: 0 (false))",
help="If graphs should be plotted during generation " "(default: 0 (false))",
)
parser.add_argument(
"--levels",
Expand Down
Loading

0 comments on commit 799d55e

Please sign in to comment.