From 65347b9613d823c81641966cb2d257d0d9ba7662 Mon Sep 17 00:00:00 2001
From: joeloskarsson <joel.oskarsson@liu.se>
Date: Wed, 13 Nov 2024 13:32:16 +0100
Subject: [PATCH] Linting

---
 neural_lam/build_graph.py |  8 +++++---
 neural_lam/vis.py         | 27 ++++++++++++++++++++++-----
 plot_graph.py             |  6 +-----
 3 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/neural_lam/build_graph.py b/neural_lam/build_graph.py
index dcbff49..c13dc62 100644
--- a/neural_lam/build_graph.py
+++ b/neural_lam/build_graph.py
@@ -41,7 +41,8 @@ def main(input_args=None):
         "--archetype",
         type=str,
         default="keisler",
-        help="Archetype to use to create graph (keisler/graphcast/hierarchical)",
+        help="Archetype to use to create graph "
+        "(keisler/graphcast/hierarchical)",
     )
     parser.add_argument(
         "--mesh_node_distance",
@@ -53,7 +54,8 @@ def main(input_args=None):
         "--level_refinement_factor",
         type=float,
         default=3,
-        help="Refinement factor between grid points and bottom level of mesh hierarchy",
+        help="Refinement factor between grid points and bottom level of "
+        "mesh hierarchy",
     )
     parser.add_argument(
         "--max_num_levels",
@@ -144,7 +146,7 @@ def main(input_args=None):
                 wmg.save.to_pyg(
                     graph=graph,
                     name=component,
-                    list_from_attribute="dummy", # Note: Needed to output list
+                    list_from_attribute="dummy",  # Note: Needed to output list
                     edge_features=["len", "vdiff"],
                     output_directory=args.output_dir,
                 )
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index fc2ddaa..4dc46ac 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -71,7 +71,7 @@ def plot_on_axis(
     vmax=None,
     ax_title=None,
     cmap="plasma",
-    grid_limits=None
+    grid_limits=None,
 ):
     """
     Plot weather state on given axis
@@ -94,7 +94,7 @@ def plot_on_axis(
         vmin=vmin,
         vmax=vmax,
         cmap=cmap,
-        extent=grid_limits
+        extent=grid_limits,
     )
 
     if ax_title:
@@ -104,7 +104,13 @@ def plot_on_axis(
 
 @matplotlib.rc_context(utils.fractional_plot_bundle(1))
 def plot_prediction(
-    pred, target, data_config, obs_mask=None, title=None, vrange=None, grid_limits=None
+    pred,
+    target,
+    data_config,
+    obs_mask=None,
+    title=None,
+    vrange=None,
+    grid_limits=None,
 ):
     """
     Plot example prediction and grond truth.
@@ -126,7 +132,9 @@ def plot_prediction(
 
     # Plot pred and target
     for ax, data in zip(axes, (target, pred)):
-        im = plot_on_axis(ax, data, data_config, obs_mask, vmin, vmax, grid_limits=grid_limits)
+        im = plot_on_axis(
+            ax, data, data_config, obs_mask, vmin, vmax, grid_limits=grid_limits
+        )
 
     # Ticks and labels
     axes[0].set_title("Ground Truth", size=15)
@@ -160,7 +168,16 @@ def plot_spatial_error(
         subplot_kw={"projection": data_config.coords_projection},
     )
 
-    im = plot_on_axis(ax, error, data_config, obs_mask, vmin, vmax, cmap="OrRd", grid_limits=grid_limits)
+    im = plot_on_axis(
+        ax,
+        error,
+        data_config,
+        obs_mask,
+        vmin,
+        vmax,
+        cmap="OrRd",
+        grid_limits=grid_limits,
+    )
 
     # Ticks and labels
     cbar = fig.colorbar(im, aspect=30)
diff --git a/plot_graph.py b/plot_graph.py
index b938bea..94def49 100644
--- a/plot_graph.py
+++ b/plot_graph.py
@@ -47,11 +47,7 @@ def main():
 
     # Load graph data
     hierarchical, graph_ldict = utils.load_graph(args.graph)
-    (
-        g2m_edge_index,
-        m2g_edge_index,
-        m2m_edge_index,
-    ) = (
+    (g2m_edge_index, m2g_edge_index, m2m_edge_index,) = (
         graph_ldict["g2m_edge_index"],
         graph_ldict["m2g_edge_index"],
         graph_ldict["m2m_edge_index"],