Skip to content

Commit

Permalink
Linting
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Nov 13, 2024
1 parent aac1ff3 commit 65347b9
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 13 deletions.
8 changes: 5 additions & 3 deletions neural_lam/build_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ def main(input_args=None):
"--archetype",
type=str,
default="keisler",
help="Archetype to use to create graph (keisler/graphcast/hierarchical)",
help="Archetype to use to create graph "
"(keisler/graphcast/hierarchical)",
)
parser.add_argument(
"--mesh_node_distance",
Expand All @@ -53,7 +54,8 @@ def main(input_args=None):
"--level_refinement_factor",
type=float,
default=3,
help="Refinement factor between grid points and bottom level of mesh hierarchy",
help="Refinement factor between grid points and bottom level of "
"mesh hierarchy",
)
parser.add_argument(
"--max_num_levels",
Expand Down Expand Up @@ -144,7 +146,7 @@ def main(input_args=None):
wmg.save.to_pyg(
graph=graph,
name=component,
list_from_attribute="dummy", # Note: Needed to output list
list_from_attribute="dummy", # Note: Needed to output list
edge_features=["len", "vdiff"],
output_directory=args.output_dir,
)
Expand Down
27 changes: 22 additions & 5 deletions neural_lam/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def plot_on_axis(
vmax=None,
ax_title=None,
cmap="plasma",
grid_limits=None
grid_limits=None,
):
"""
Plot weather state on given axis
Expand All @@ -94,7 +94,7 @@ def plot_on_axis(
vmin=vmin,
vmax=vmax,
cmap=cmap,
extent=grid_limits
extent=grid_limits,
)

if ax_title:
Expand All @@ -104,7 +104,13 @@ def plot_on_axis(

@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_prediction(
pred, target, data_config, obs_mask=None, title=None, vrange=None, grid_limits=None
pred,
target,
data_config,
obs_mask=None,
title=None,
vrange=None,
grid_limits=None,
):
"""
Plot example prediction and grond truth.
Expand All @@ -126,7 +132,9 @@ def plot_prediction(

# Plot pred and target
for ax, data in zip(axes, (target, pred)):
im = plot_on_axis(ax, data, data_config, obs_mask, vmin, vmax, grid_limits=grid_limits)
im = plot_on_axis(
ax, data, data_config, obs_mask, vmin, vmax, grid_limits=grid_limits
)

# Ticks and labels
axes[0].set_title("Ground Truth", size=15)
Expand Down Expand Up @@ -160,7 +168,16 @@ def plot_spatial_error(
subplot_kw={"projection": data_config.coords_projection},
)

im = plot_on_axis(ax, error, data_config, obs_mask, vmin, vmax, cmap="OrRd", grid_limits=grid_limits)
im = plot_on_axis(
ax,
error,
data_config,
obs_mask,
vmin,
vmax,
cmap="OrRd",
grid_limits=grid_limits,
)

# Ticks and labels
cbar = fig.colorbar(im, aspect=30)
Expand Down
6 changes: 1 addition & 5 deletions plot_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,7 @@ def main():

# Load graph data
hierarchical, graph_ldict = utils.load_graph(args.graph)
(
g2m_edge_index,
m2g_edge_index,
m2m_edge_index,
) = (
(g2m_edge_index, m2g_edge_index, m2m_edge_index,) = (
graph_ldict["g2m_edge_index"],
graph_ldict["m2g_edge_index"],
graph_ldict["m2m_edge_index"],
Expand Down

0 comments on commit 65347b9

Please sign in to comment.