diff --git a/neural_lam/build_graph.py b/neural_lam/build_graph.py index dcbff49..c13dc62 100644 --- a/neural_lam/build_graph.py +++ b/neural_lam/build_graph.py @@ -41,7 +41,8 @@ def main(input_args=None): "--archetype", type=str, default="keisler", - help="Archetype to use to create graph (keisler/graphcast/hierarchical)", + help="Archetype to use to create graph " + "(keisler/graphcast/hierarchical)", ) parser.add_argument( "--mesh_node_distance", @@ -53,7 +54,8 @@ def main(input_args=None): "--level_refinement_factor", type=float, default=3, - help="Refinement factor between grid points and bottom level of mesh hierarchy", + help="Refinement factor between grid points and bottom level of " + "mesh hierarchy", ) parser.add_argument( "--max_num_levels", @@ -144,7 +146,7 @@ def main(input_args=None): wmg.save.to_pyg( graph=graph, name=component, - list_from_attribute="dummy", # Note: Needed to output list + list_from_attribute="dummy", # Note: Needed to output list edge_features=["len", "vdiff"], output_directory=args.output_dir, ) diff --git a/neural_lam/vis.py b/neural_lam/vis.py index fc2ddaa..4dc46ac 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -71,7 +71,7 @@ def plot_on_axis( vmax=None, ax_title=None, cmap="plasma", - grid_limits=None + grid_limits=None, ): """ Plot weather state on given axis @@ -94,7 +94,7 @@ def plot_on_axis( vmin=vmin, vmax=vmax, cmap=cmap, - extent=grid_limits + extent=grid_limits, ) if ax_title: @@ -104,7 +104,13 @@ def plot_on_axis( @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_prediction( - pred, target, data_config, obs_mask=None, title=None, vrange=None, grid_limits=None + pred, + target, + data_config, + obs_mask=None, + title=None, + vrange=None, + grid_limits=None, ): """ Plot example prediction and grond truth. @@ -126,7 +132,9 @@ def plot_prediction( # Plot pred and target for ax, data in zip(axes, (target, pred)): - im = plot_on_axis(ax, data, data_config, obs_mask, vmin, vmax, grid_limits=grid_limits) + im = plot_on_axis( + ax, data, data_config, obs_mask, vmin, vmax, grid_limits=grid_limits + ) # Ticks and labels axes[0].set_title("Ground Truth", size=15) @@ -160,7 +168,16 @@ def plot_spatial_error( subplot_kw={"projection": data_config.coords_projection}, ) - im = plot_on_axis(ax, error, data_config, obs_mask, vmin, vmax, cmap="OrRd", grid_limits=grid_limits) + im = plot_on_axis( + ax, + error, + data_config, + obs_mask, + vmin, + vmax, + cmap="OrRd", + grid_limits=grid_limits, + ) # Ticks and labels cbar = fig.colorbar(im, aspect=30) diff --git a/plot_graph.py b/plot_graph.py index b938bea..94def49 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -47,11 +47,7 @@ def main(): # Load graph data hierarchical, graph_ldict = utils.load_graph(args.graph) - ( - g2m_edge_index, - m2g_edge_index, - m2m_edge_index, - ) = ( + (g2m_edge_index, m2g_edge_index, m2m_edge_index,) = ( graph_ldict["g2m_edge_index"], graph_ldict["m2g_edge_index"], graph_ldict["m2m_edge_index"],