diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 00000000..a6ad84f1 --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,33 @@ +name: Run pre-commit job + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + pre-commit-job: + runs-on: ubuntu-latest + defaults: + run: + shell: bash -l {0} + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: 3.9 + - name: Install pre-commit hooks + run: | + pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 \ + --index-url https://download.pytorch.org/whl/cpu + pip install -r requirements.txt + pip install pyg-lib==0.2.0 torch-scatter==2.1.1 torch-sparse==0.6.17 \ + torch-cluster==1.6.1 torch-geometric==2.3.1 \ + -f https://pytorch-geometric.com/whl/torch-2.0.1+cpu.html + - name: Run pre-commit hooks + run: | + pre-commit run --all-files diff --git a/.gitignore b/.gitignore index 7806b590..5ca89369 100644 --- a/.gitignore +++ b/.gitignore @@ -8,6 +8,7 @@ graphs sweeps test_*.sh lightning_logs +.vscode ### Python ### # Byte-compiled / optimized / DLL files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..f48eca67 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,51 @@ +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: check-ast + - id: check-case-conflict + - id: check-docstring-first + - id: check-symlinks + - id: check-toml + - id: check-yaml + - id: debug-statements + - id: end-of-file-fixer + - id: trailing-whitespace +- repo: local + hooks: + - id: codespell + name: codespell + description: Check for spelling errors + language: system + entry: codespell +- repo: local + hooks: + - id: black + name: black + description: Format Python code + language: system + entry: black + types_or: [python, pyi] +- repo: local + hooks: + - id: isort + name: isort + description: Group and sort Python imports + language: system + entry: isort + types_or: [python, pyi, cython] +- repo: local + hooks: + - id: flake8 + name: flake8 + description: Check Python code for correctness, consistency and adherence to best practices + language: system + entry: flake8 --max-line-length=80 --ignore=E203,F811,I002,W503 + types: [python] +- repo: local + hooks: + - id: pylint + name: pylint + entry: pylint -rn -sn + language: system + types: [python] diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index a8220581..00000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. - // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 - "version": "0.2.0", - "configurations": [ - { - "name": "Python: Current File", - "type": "python", - "request": "launch", - "program": "${file}", - "console": "integratedTerminal", - "justMyCode": true, - "args": [ - "--dataset", - "cosmo", - "--plot", - "1" - ] - } - ] -} \ No newline at end of file diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 68146b39..00000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "diffEditor.ignoreTrimWhitespace": true -} \ No newline at end of file diff --git a/README.md b/README.md index 879380d6..86630683 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,7 @@

+ Neural-LAM is a repository of graph-based neural weather prediction models for Limited Area Modeling (LAM). The code uses [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/pytorch-lightning). Graph Neural Networks are implemented using [PyG](https://pyg.org/) and logging is set up through [Weights & Biases](https://wandb.ai/). @@ -11,16 +12,18 @@ The repository contains LAM versions of: * GraphCast, by [Lam et al. (2023)](https://arxiv.org/abs/2212.12794). * The hierarchical model from [Oskarsson et al. (2023)](https://arxiv.org/abs/2309.17370). -For more information see our preprint: [*Graph-based Neural Weather Prediction for Limited Area Modeling*](https://arxiv.org/abs/2309.17370). +For more information see our paper: [*Graph-based Neural Weather Prediction for Limited Area Modeling*](https://arxiv.org/abs/2309.17370). If you use Neural-LAM in your work, please cite: ``` -@article{oskarsson2023graphbased, - title={Graph-based Neural Weather Prediction for Limited Area Modeling}, - author={Joel Oskarsson and Tomas Landelius and Fredrik Lindsten}, - year={2023}, - journal={arXiv preprint arXiv:2309.17370} +@inproceedings{oskarsson2023graphbased, + title={Graph-based Neural Weather Prediction for Limited Area Modeling}, + author={Oskarsson, Joel and Landelius, Tomas and Lindsten, Fredrik}, + booktitle={NeurIPS 2023 Workshop on Tackling Climate Change with Machine Learning}, + year={2023} } ``` +As the code in the repository is continuously evolving, the latest version might feature some small differences to what was used in the paper. +See the branch [`ccai_paper_2023`](https://github.com/joeloskarsson/neural-lam/tree/ccai_paper_2023) for a revision of the code that reproduces the workshop paper. We plan to continue updating this repository as we improve existing models and develop new ones. Collaborations around this implementation are very welcome. @@ -47,10 +50,11 @@ mamba env create -f environment.yml mamba activate neural-lam # Run the preprocessing/training scripts +# (don't execute preprocessing scripts at the same time as training) sbatch slurm_train.sh # Run the evaluation script and generate plots and gif for TQV -# (don't execute preprocessing scripts at the same time as training) +# (by default this will use the pre-trained model from `wandb/example.ckpt`) sbatch slurm_eval.sh ``` @@ -101,9 +105,9 @@ Note that only the cuda version is pinned to 11.8, otherwise all the latest libr \ -Follow the steps below to create the neccesary python environment. +Follow the steps below to create the necessary python environment. -1. Install GEOS for your system. For example with `sudo apt-get install libgeos-dev`. This is neccesary for the Cartopy requirement. +1. Install GEOS for your system. For example with `sudo apt-get install libgeos-dev`. This is necessary for the Cartopy requirement. 2. Use python 3.9. 3. Install version 2.0.1 of PyTorch. Follow instructions on the [PyTorch webpage](https://pytorch.org/get-started/previous-versions/) for how to set this up with GPU support on your system. 4. Install required packages specified in `requirements.txt`. @@ -233,7 +237,7 @@ python train_model.py --model hi_lam --graph hierarchical ... ``` ### Hi-LAM-Parallel -A version of Hi-LAM where all message passing in the hierarchical mesh (up, down, inter-level) is ran in paralell. +A version of Hi-LAM where all message passing in the hierarchical mesh (up, down, inter-level) is ran in parallel. Not included in the paper as initial experiments showed worse results than Hi-LAM, but could be interesting to try in more settings. To train Hi-LAM-Parallel use @@ -343,6 +347,16 @@ In addition, hierarchical mesh graphs (`L > 1`) feature a few additional files w These files have the same list format as the ones above, but each list has length `L-1` (as these edges describe connections between levels). Entries 0 in these lists describe edges between the lowest levels 1 and 2. +# Development and Contributing +Any push or Pull-Request to the main branch will trigger a selection of pre-commit hooks. +These hooks will run a series of checks on the code, like formatting and linting. +If any of these checks fail the push or PR will be rejected. +To test whether your code passes these checks before pushing, run +``` bash +pre-commit run --all-files +``` +from the root directory of the repository. + # Contact If you are interested in machine learning models for LAM, have questions about our implementation or ideas for extending it, feel free to get in touch. You can open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). diff --git a/create_grid_features.py b/create_grid_features.py index fe266751..a28d8df8 100644 --- a/create_grid_features.py +++ b/create_grid_features.py @@ -1,47 +1,62 @@ +# Standard library import os from argparse import ArgumentParser +# Third-party import numpy as np import torch def main(): - parser = ArgumentParser(description='Training arguments') - parser.add_argument('--dataset', type=str, default="meps_example", - help='Dataset to compute weights for (default: meps_example)') + """ + Pre-compute all static features related to the grid nodes + """ + parser = ArgumentParser(description="Training arguments") + parser.add_argument( + "--dataset", + type=str, + default="meps_example", + help="Dataset to compute weights for (default: meps_example)", + ) args = parser.parse_args() static_dir_path = os.path.join("data", args.dataset, "static") # -- Static grid node features -- - grid_xy = torch.tensor(np.load(os.path.join(static_dir_path, "nwp_xy.npy") - )) # (2, N_x, N_y) + grid_xy = torch.tensor( + np.load(os.path.join(static_dir_path, "nwp_xy.npy")) + ) # (2, N_x, N_y) grid_xy = grid_xy.flatten(1, 2).T # (N_grid, 2) pos_max = torch.max(torch.abs(grid_xy)) grid_xy = grid_xy / pos_max # Divide by maximum coordinate geopotential = torch.tensor( np.load( - os.path.join( - static_dir_path, - "reference_geopotential_pressure.npy"))) # (N_x, N_y) - geopotential = geopotential.flatten(0, 1) # (N_grid, N_static) + os.path.join(static_dir_path, "reference_geopotential_pressure.npy") + ) + ) # (N_x, N_y, N_fields) + geopotential = geopotential.flatten(0, 1) # (N_grid, N_fields) + gp_min = torch.min(geopotential) + gp_max = torch.max(geopotential) + # Rescale geopotential to [0,1] + geopotential = (geopotential - gp_min) / ( + gp_max - gp_min + ) # (N_grid, N_fields) grid_border_mask = torch.tensor( - np.load( - os.path.join( - static_dir_path, - "border_mask.npy")), - dtype=torch.int64) # (N_x, N_y) - grid_border_mask = grid_border_mask.flatten(0, 1).to( - torch.float).unsqueeze(1) # (N_grid, 1) + np.load(os.path.join(static_dir_path, "border_mask.npy")), + dtype=torch.int64, + ) # (N_x, N_y) + grid_border_mask = ( + grid_border_mask.flatten(0, 1).to(torch.float).unsqueeze(1) + ) # (N_grid, 1) # Concatenate grid features - grid_features = torch.cat((grid_xy, geopotential, grid_border_mask), - dim=1) # (N_grid, 3 + N_static) + grid_features = torch.cat( + (grid_xy, geopotential, grid_border_mask), dim=1 + ) # (N_grid, 4) - torch.save(grid_features, os.path.join( - static_dir_path, "grid_features.pt")) + torch.save(grid_features, os.path.join(static_dir_path, "grid_features.pt")) if __name__ == "__main__": diff --git a/create_mesh.py b/create_mesh.py index 5be04897..7fbb7d83 100644 --- a/create_mesh.py +++ b/create_mesh.py @@ -1,6 +1,8 @@ +# Standard library import os from argparse import ArgumentParser +# Third-party import matplotlib import matplotlib.pyplot as plt import networkx @@ -10,10 +12,12 @@ import torch_geometric as pyg from torch_geometric.utils.convert import from_networkx +# First-party from neural_lam import constants + def plot_graph(graph, title=None): - fig, axis = plt.subplots(dpi=200) # W,H + fig, axis = plt.subplots(figsize=(8, 8), dpi=200) # W,H edge_index = graph.edge_index pos = graph.pos @@ -23,35 +27,43 @@ def plot_graph(graph, title=None): if pyg.utils.is_undirected(edge_index): # Keep only 1 direction of edge_index - edge_index = edge_index[:, edge_index[0] < edge_index[1]] # (2, M/2) + edge_index = edge_index[:, edge_index[0] < edge_index[1]] # (2, M/2) # TODO: indicate direction of directed edges # Move all to cpu and numpy, compute (in)-degrees - # Flatten the edge_index to consider both source and target nodes - flattened_edge_index = edge_index.flatten() - # Calculate degrees by counting occurrences of each node index - degrees = np.bincount(flattened_edge_index, minlength=pos.shape[0]) + degrees = ( + pyg.utils.degree(edge_index[1], num_nodes=pos.shape[0]).cpu().numpy() + ) edge_index = edge_index.cpu().numpy() pos = pos.cpu().numpy() # Plot edges - from_pos = pos[edge_index[0]] # (M/2, 2) - to_pos = pos[edge_index[1]] # (M/2, 2) + from_pos = pos[edge_index[0]] # (M/2, 2) + to_pos = pos[edge_index[1]] # (M/2, 2) edge_lines = np.stack((from_pos, to_pos), axis=1) - axis.add_collection(matplotlib.collections.LineCollection(edge_lines, lw=0.4, - colors="black", zorder=1)) + axis.add_collection( + matplotlib.collections.LineCollection( + edge_lines, lw=0.4, colors="black", zorder=1 + ) + ) # Plot nodes node_scatter = axis.scatter( pos[:, 0], pos[:, 1], - c=degrees, s=3, marker="o", zorder=2, cmap="viridis", clim=None) + c=degrees, + s=3, + marker="o", + zorder=2, + cmap="viridis", + clim=None, + ) plt.colorbar(node_scatter, aspect=50) margin = 0.5 - axis.set_xlim(left=0-margin, right=constants.grid_shape[0] + margin) - axis.set_ylim(bottom=0-margin, top=constants.grid_shape[1] + margin) + axis.set_xlim(left=0 - margin, right=constants.GRID_SHAPE[0] + margin) + axis.set_ylim(bottom=0 - margin, top=constants.GRID_SHAPE[1] + margin) if title is not None: axis.set_title(title) @@ -61,7 +73,8 @@ def plot_graph(graph, title=None): def sort_nodes_internally(nx_graph): # For some reason the networkx .nodes() return list can not be sorted, - # but this is the ordering used by pyg when converting. This function fixes this + # but this is the ordering used by pyg when converting. + # This function fixes this. H = networkx.DiGraph() H.add_nodes_from(sorted(nx_graph.nodes(data=True))) H.add_edges_from(nx_graph.edges(data=True)) @@ -69,19 +82,26 @@ def sort_nodes_internally(nx_graph): def save_edges(graph, name, base_path): - torch.save(graph.edge_index, os.path.join(base_path, f"{name}_edge_index.pt")) - edge_features = torch.cat((graph.len.unsqueeze(1), graph.vdiff), - dim=1).to(torch.float32) # Save as float32 + torch.save( + graph.edge_index, os.path.join(base_path, f"{name}_edge_index.pt") + ) + edge_features = torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to( + torch.float32 + ) # Save as float32 torch.save(edge_features, os.path.join(base_path, f"{name}_features.pt")) def save_edges_list(graphs, name, base_path): - torch.save([graph.edge_index for graph in graphs], - os.path.join(base_path, f"{name}_edge_index.pt")) + torch.save( + [graph.edge_index for graph in graphs], + os.path.join(base_path, f"{name}_edge_index.pt"), + ) edge_features = [ - torch.cat( - (graph.len.unsqueeze(1), graph.vdiff), dim=1).to( - torch.float32) for graph in graphs] # Save as float32 + torch.cat((graph.len.unsqueeze(1), graph.vdiff), dim=1).to( + torch.float32 + ) + for graph in graphs + ] # Save as float32 torch.save(edge_features, os.path.join(base_path, f"{name}_features.pt")) @@ -105,28 +125,27 @@ def mk_2d_graph(xy, nx, ny): g = networkx.grid_2d_graph(len(ly), len(lx)) for node in g.nodes: - g.nodes[node]['pos'] = np.array([mg[0][node], mg[1][node]]) + g.nodes[node]["pos"] = np.array([mg[0][node], mg[1][node]]) # add diagonal edges - g.add_edges_from([ - ((x, y), (x + 1, y + 1)) - for x in range(nx - 1) - for y in range(ny - 1) - ] + [ - ((x + 1, y), (x, y + 1)) - for x in range(nx - 1) - for y in range(ny - 1) - ]) + g.add_edges_from( + [((x, y), (x + 1, y + 1)) for x in range(nx - 1) for y in range(ny - 1)] + + [ + ((x + 1, y), (x, y + 1)) + for x in range(nx - 1) + for y in range(ny - 1) + ] + ) # turn into directed graph dg = networkx.DiGraph(g) - for (u, v) in g.edges(): - d = np.sqrt(np.sum((g.nodes[u]['pos'] - g.nodes[v]['pos'])**2)) - dg.edges[u, v]['len'] = d - dg.edges[u, v]['vdiff'] = g.nodes[u]['pos'] - g.nodes[v]['pos'] + for u, v in g.edges(): + d = np.sqrt(np.sum((g.nodes[u]["pos"] - g.nodes[v]["pos"]) ** 2)) + dg.edges[u, v]["len"] = d + dg.edges[u, v]["vdiff"] = g.nodes[u]["pos"] - g.nodes[v]["pos"] dg.add_edge(v, u) - dg.edges[v, u]['len'] = d - dg.edges[v, u]['vdiff'] = g.nodes[v]['pos'] - g.nodes[u]['pos'] + dg.edges[v, u]["len"] = d + dg.edges[v, u]["vdiff"] = g.nodes[v]["pos"] - g.nodes[u]["pos"] return dg @@ -139,20 +158,39 @@ def prepend_node_index(graph, new_index): def main(): - parser = ArgumentParser(description='Graph genreation arguments') + parser = ArgumentParser(description="Graph generation arguments") + parser.add_argument( + "--dataset", + type=str, + default="meps_example", + help="Dataset to load grid point coordinates from " + "(default: meps_example)", + ) + parser.add_argument( + "--graph", + type=str, + default="multiscale", + help="Name to save graph as (default: multiscale)", + ) parser.add_argument( - '--dataset', type=str, default="meps_example", - help='Dataset to load grid point coordinates from (default: meps_example)') - parser.add_argument('--graph', type=str, default="multiscale", - help='Name to save graph as (default: multiscale)') + "--plot", + type=int, + default=0, + help="If graphs should be plotted during generation " + "(default: 0 (false))", + ) parser.add_argument( - '--plot', type=int, default=0, - help='If graphs should be plotted during generation (default: 0 (false))') - parser.add_argument('--levels', type=int, - help='Limit multi-scale mesh to given number of levels, ' - 'from bottom up (default: None (no limit))') - parser.add_argument('--hierarchical', type=int, default=0, - help='Generate hierarchical mesh graph (default: 0, no)') + "--levels", + type=int, + help="Limit multi-scale mesh to given number of levels, " + "from bottom up (default: None (no limit))", + ) + parser.add_argument( + "--hierarchical", + type=int, + default=0, + help="Generate hierarchical mesh graph (default: 0, no)", + ) args = parser.parse_args() # Load grid positions @@ -170,7 +208,7 @@ def main(): # # graph geometry - nx = constants.graph_num_children # number of children = nx**2 + nx = constants.GRAPH_NUM_CHILDREN # number of children = nx**2 nlev = int(np.log(max(xy.shape)) / np.log(nx)) nleaf = nx**nlev # leaves at the bottom = nleaf**2 @@ -195,24 +233,27 @@ def main(): if args.hierarchical: # Relabel nodes of each level with level index first - G = [prepend_node_index(graph, level_i) for level_i, graph in enumerate(G)] + G = [ + prepend_node_index(graph, level_i) + for level_i, graph in enumerate(G) + ] num_nodes_level = np.array([len(g_level.nodes) for g_level in G]) - # First node index in each level in the hierarcical graph - first_index_level = np.concatenate(( - np.zeros(1, dtype=int), - np.cumsum(num_nodes_level[:-1]))) + # First node index in each level in the hierarchical graph + first_index_level = np.concatenate( + (np.zeros(1, dtype=int), np.cumsum(num_nodes_level[:-1])) + ) # Create inter-level mesh edges up_graphs = [] down_graphs = [] for from_level, to_level, G_from, G_to, start_index in zip( - range(1, mesh_levels), - range(0, mesh_levels - 1), - G[1:], - G[:-1], - first_index_level[:mesh_levels - 1]): - + range(1, mesh_levels), + range(0, mesh_levels - 1), + G[1:], + G[:-1], + first_index_level[: mesh_levels - 1], + ): # start out from graph at from level G_down = G_from.copy() G_down.clear_edges() @@ -225,34 +266,38 @@ def main(): # order in vm should be same as in vm_xy v_to_list = list(G_to.nodes) v_from_list = list(G_from.nodes) - v_from_xy = np.array([xy for _, xy in G_from.nodes.data('pos')]) + v_from_xy = np.array([xy for _, xy in G_from.nodes.data("pos")]) kdt_m = scipy.spatial.KDTree(v_from_xy) # add edges from mesh to grid for v in v_to_list: # find 1(?) nearest neighbours (index to vm_xy) - neigh_idx = kdt_m.query(G_down.nodes[v]['pos'], 1)[1] + neigh_idx = kdt_m.query(G_down.nodes[v]["pos"], 1)[1] u = v_from_list[neigh_idx] # add edge from mesh to grid G_down.add_edge(u, v) d = np.sqrt( np.sum( - (G_down.nodes[u]['pos'] - - G_down.nodes[v]['pos'])**2)) - G_down.edges[u, v]['len'] = d - G_down.edges[u, - v]['vdiff'] = G_down.nodes[u]['pos'] - G_down.nodes[v]['pos'] + (G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"]) ** 2 + ) + ) + G_down.edges[u, v]["len"] = d + G_down.edges[u, v]["vdiff"] = ( + G_down.nodes[u]["pos"] - G_down.nodes[v]["pos"] + ) # relabel nodes to integers (sorted) G_down_int = networkx.convert_node_labels_to_integers( - G_down, first_label=start_index, ordering='sorted') # Issue with sorting here + G_down, first_label=start_index, ordering="sorted" + ) # Issue with sorting here G_down_int = sort_nodes_internally(G_down_int) pyg_down = from_networkx_with_start_index(G_down_int, start_index) # Create up graph, invert downwards edges - up_edges = torch.stack((pyg_down.edge_index[1], pyg_down.edge_index[0]), - dim=0) + up_edges = torch.stack( + (pyg_down.edge_index[1], pyg_down.edge_index[0]), dim=0 + ) pyg_up = pyg_down.clone() pyg_up.edge_index = up_edges @@ -260,17 +305,15 @@ def main(): down_graphs.append(pyg_down) if args.plot: - plot_graph(pyg_down, title=f"Down graph, {from_level} -> {to_level}") - plt.savefig( - os.path.join( - graph_dir_path, - f"mesh_down_graph_{from_level}.png")) - - plot_graph(pyg_down, title=f"Up graph, {to_level} -> {from_level}") - plt.savefig( - os.path.join( - graph_dir_path, - f"mesh_up_graph_{to_level}.png")) + plot_graph( + pyg_down, title=f"Down graph, {from_level} -> {to_level}" + ) + plt.show() + + plot_graph( + pyg_down, title=f"Up graph, {to_level} -> {from_level}" + ) + plt.show() # Save up and down edges save_edges_list(up_graphs, "mesh_up", graph_dir_path) @@ -280,9 +323,12 @@ def main(): m2m_graphs = [ from_networkx_with_start_index( networkx.convert_node_labels_to_integers( - level_graph, first_label=start_index, ordering='sorted'), - start_index) for level_graph, start_index in zip( - G, first_index_level)] + level_graph, first_label=start_index, ordering="sorted" + ), + start_index, + ) + for level_graph, start_index in zip(G, first_index_level) + ] mesh_pos = [graph.pos.to(torch.float32) for graph in m2m_graphs] @@ -298,10 +344,11 @@ def main(): for lev in range(1, len(G)): nodes = list(G[lev - 1].nodes) n = int(np.sqrt(len(nodes))) - ij = np.array(nodes).reshape( - (n, n, 2))[ - 1:: nx, 1:: nx, :].reshape( - int(n / nx) ** 2, 2) + ij = ( + np.array(nodes) + .reshape((n, n, 2))[1::nx, 1::nx, :] + .reshape(int(n / nx) ** 2, 2) + ) ij = [tuple(x) for x in ij] G[lev] = networkx.relabel_nodes(G[lev], dict(zip(G[lev].nodes, ij))) G_tot = networkx.compose(G_tot, G[lev]) @@ -310,8 +357,9 @@ def main(): G_tot = prepend_node_index(G_tot, 0) # relabel nodes to integers (sorted) - G_int = networkx.convert_node_labels_to_integers(G_tot, first_label=0, - ordering='sorted') + G_int = networkx.convert_node_labels_to_integers( + G_tot, first_label=0, ordering="sorted" + ) # Graph to use in g2m and m2g G_bottom_mesh = G_tot @@ -333,8 +381,9 @@ def main(): mesh_pos = [pos / pos_max for pos in mesh_pos] # Save mesh positions - torch.save(mesh_pos, os.path.join(graph_dir_path, - "mesh_features.pt")) # mesh pos, in float32 + torch.save( + mesh_pos, os.path.join(graph_dir_path, "mesh_features.pt") + ) # mesh pos, in float32 # # Grid2Mesh @@ -346,9 +395,11 @@ def main(): # mesh nodes on lowest level vm = G_bottom_mesh.nodes - vm_xy = np.array([xy for _, xy in vm.data('pos')]) + vm_xy = np.array([xy for _, xy in vm.data("pos")]) # distance between mesh nodes - dm = np.sqrt(np.sum((vm.data('pos')[(0, 1, 0)] - vm.data('pos')[(0, 0, 0)])**2)) + dm = np.sqrt( + np.sum((vm.data("pos")[(0, 1, 0)] - vm.data("pos")[(0, 0, 0)]) ** 2) + ) # grid nodes Ny, Nx = xy.shape[1:] @@ -359,10 +410,10 @@ def main(): # vg features (only pos introduced here) for node in G_grid.nodes: # pos is in feature but here explicit for convenience - G_grid.nodes[node]['pos'] = np.array([xy[0][node], xy[1][node]]) + G_grid.nodes[node]["pos"] = np.array([xy[0][node], xy[1][node]]) - # add 1000 to node key to separate grid nodes (1000,i,j) from mesh nodes (i,j) - # and impose sorting order such that vm are the first nodes + # add 1000 to node key to separate grid nodes (1000,i,j) from mesh nodes + # (i,j) and impose sorting order such that vm are the first nodes G_grid = prepend_node_index(G_grid, 1000) # build kd tree for grid point pos @@ -385,14 +436,18 @@ def main(): # add edges for v in vm: # find neighbours (index to vg_xy) - neigh_idxs = kdt_g.query_ball_point(vm[v]['pos'], dm * DM_SCALE) + neigh_idxs = kdt_g.query_ball_point(vm[v]["pos"], dm * DM_SCALE) for i in neigh_idxs: u = vg_list[i] # add edge from grid to mesh G_g2m.add_edge(u, v) - d = np.sqrt(np.sum((G_g2m.nodes[u]['pos'] - G_g2m.nodes[v]['pos'])**2)) - G_g2m.edges[u, v]['len'] = d - G_g2m.edges[u, v]['vdiff'] = G_g2m.nodes[u]['pos'] - G_g2m.nodes[v]['pos'] + d = np.sqrt( + np.sum((G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"]) ** 2) + ) + G_g2m.edges[u, v]["len"] = d + G_g2m.edges[u, v]["vdiff"] = ( + G_g2m.nodes[u]["pos"] - G_g2m.nodes[v]["pos"] + ) pyg_g2m = from_networkx(G_g2m) @@ -416,18 +471,23 @@ def main(): # add edges from mesh to grid for v in vg_list: # find 4 nearest neighbours (index to vm_xy) - neigh_idxs = kdt_m.query(G_m2g.nodes[v]['pos'], 4)[1] + neigh_idxs = kdt_m.query(G_m2g.nodes[v]["pos"], 4)[1] for i in neigh_idxs: u = vm_list[i] # add edge from mesh to grid G_m2g.add_edge(u, v) - d = np.sqrt(np.sum((G_m2g.nodes[u]['pos'] - G_m2g.nodes[v]['pos'])**2)) - G_m2g.edges[u, v]['len'] = d - G_m2g.edges[u, v]['vdiff'] = G_m2g.nodes[u]['pos'] - G_m2g.nodes[v]['pos'] + d = np.sqrt( + np.sum((G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"]) ** 2) + ) + G_m2g.edges[u, v]["len"] = d + G_m2g.edges[u, v]["vdiff"] = ( + G_m2g.nodes[u]["pos"] - G_m2g.nodes[v]["pos"] + ) # relabel nodes to integers (sorted) - G_m2g_int = networkx.convert_node_labels_to_integers(G_m2g, first_label=0, - ordering='sorted') + G_m2g_int = networkx.convert_node_labels_to_integers( + G_m2g, first_label=0, ordering="sorted" + ) pyg_m2g = from_networkx(G_m2g_int) if args.plot: diff --git a/create_parameter_weights.py b/create_parameter_weights.py index a11c8781..f9cab328 100644 --- a/create_parameter_weights.py +++ b/create_parameter_weights.py @@ -1,33 +1,61 @@ +# Standard library import os from argparse import ArgumentParser +# Third-party import numpy as np import torch from tqdm import tqdm +# First-party from neural_lam import constants from neural_lam.weather_dataset import WeatherDataset def main(): - parser = ArgumentParser(description='Training arguments') - parser.add_argument('--dataset', type=str, default="meps_example", - help='Dataset to compute weights for (default: meps_example)') - parser.add_argument('--batch_size', type=int, default=32, - help='Batch size when iterating over the dataset') - parser.add_argument('--n_workers', type=int, default=4, - help='Number of workers in data loader (default: 4)') + """ + Pre-compute parameter weights to be used in loss function + """ + parser = ArgumentParser(description="Training arguments") + parser.add_argument( + "--dataset", + type=str, + default="meps_example", + help="Dataset to compute weights for (default: meps_example)", + ) + parser.add_argument( + "--batch_size", + type=int, + default=32, + help="Batch size when iterating over the dataset", + ) + parser.add_argument( + "--step_length", + type=int, + default=1, + help="Step length in hours to consider single time step (default: 1)", + ) + parser.add_argument( + "--n_workers", + type=int, + default=4, + help="Number of workers in data loader (default: 4)", + ) args = parser.parse_args() static_dir_path = os.path.join("data", args.dataset, "static") - # Define weights for each vertical level and parameter # Create parameter weights based on height w_list = [] - for var_name, pw in zip(constants.param_names_short, - constants.param_weights.values()): + for var_name, pw in zip( + constants.PARAM_NAMES_SHORT, constants.PARAM_WEIGHTS.values() + ): # Determine the levels to iterate over - levels = constants.level_weights.values() if constants.is_3d[var_name] else [1] + levels = ( + constants.LEVEL_WEIGHTS.values() + if constants.IS_3D[var_name] + else [1] + ) # Iterate over the levels for lw in levels: @@ -37,33 +65,56 @@ def main(): print("Saving parameter weights...") np.save( - os.path.join( - static_dir_path, - 'parameter_weights.npy'), - w_list.astype('float32')) + os.path.join(static_dir_path, "parameter_weights.npy"), + w_list.astype("float32"), + ) # Load dataset without any subsampling ds = WeatherDataset( args.dataset, split="train", - standardize=False) # Without standardization - loader = torch.utils.data.DataLoader(ds, args.batch_size, shuffle=False, - num_workers=args.n_workers) - # Compute mean and std.-dev. of each parameter (+ flux forcing) across full dataset + standardize=False, + ) # Without standardization + loader = torch.utils.data.DataLoader( + ds, args.batch_size, shuffle=False, num_workers=args.n_workers + ) + # Compute mean and std.-dev. of each parameter (+ flux forcing) + # across full dataset print("Computing mean and std.-dev. for parameters...") + means = [] squares = [] - for init_batch, target_batch in tqdm(loader): - batch = torch.cat((init_batch, target_batch), - dim=1) # (N_batch, N_t, N_grid, d_features) + flux_means = [] + flux_squares = [] + for batch_data in tqdm(loader): + if constants.GRID_FORCING_DIM > 0: + init_batch, target_batch, forcing_batch = batch_data + flux_batch = forcing_batch[:, :, :, 0] # Flux is first index + flux_means.append(torch.mean(flux_batch)) # (,) + flux_squares.append(torch.mean(flux_batch**2)) # (,) + else: + init_batch, target_batch = batch_data + + batch = torch.cat( + (init_batch, target_batch), dim=1 + ) # (N_batch, N_t, N_grid, d_features) means.append(torch.mean(batch, dim=(1, 2))) # (N_batch, d_features,) - # (N_batch, d_features,) - squares.append(torch.mean(batch**2, dim=(1, 2))) + squares.append( + torch.mean(batch**2, dim=(1, 2)) + ) # (N_batch, d_features,) mean = torch.mean(torch.cat(means, dim=0), dim=0) # (d_features) second_moment = torch.mean(torch.cat(squares, dim=0), dim=0) std = torch.sqrt(second_moment - mean**2) # (d_features) + if constants.GRID_FORCING_DIM > 0: + flux_mean = torch.mean(torch.stack(flux_means)) # (,) + flux_second_moment = torch.mean(torch.stack(flux_squares)) # (,) + flux_std = torch.sqrt(flux_second_moment - flux_mean**2) # (,) + flux_stats = torch.stack((flux_mean, flux_std)) + + print("Saving mean flux_stats...") + torch.save(flux_stats, os.path.join(static_dir_path, "flux_stats.pt")) print("Saving mean, std.-dev...") torch.save(mean, os.path.join(static_dir_path, "parameter_mean.pt")) torch.save(std, os.path.join(static_dir_path, "parameter_std.pt")) @@ -73,20 +124,31 @@ def main(): ds_standard = WeatherDataset( args.dataset, split="train", - standardize=True) # Re-load with standardization + standardize=True, + ) # Re-load with standardization loader_standard = torch.utils.data.DataLoader( - ds_standard, args.batch_size, shuffle=False, num_workers=args.n_workers) + ds_standard, args.batch_size, shuffle=False, num_workers=args.n_workers + ) diff_means = [] diff_squares = [] - for init_batch, target_batch, in tqdm(loader_standard): + for batch_data in tqdm(loader_standard): + if constants.GRID_FORCING_DIM > 0: + init_batch, target_batch, forcing_batch = batch_data + flux_batch = forcing_batch[:, :, :, 0] # Flux is first index + flux_means.append(torch.mean(flux_batch)) # (,) + flux_squares.append(torch.mean(flux_batch**2)) # (,) + else: + init_batch, target_batch = batch_data batch_diffs = init_batch[:, 1:] - target_batch # (N_batch', N_t-1, N_grid, d_features) - diff_means.append(torch.mean(batch_diffs, dim=(1, 2)) - ) # (N_batch', d_features,) - diff_squares.append(torch.mean(batch_diffs**2, - dim=(1, 2))) # (N_batch', d_features,) + diff_means.append( + torch.mean(batch_diffs, dim=(1, 2)) + ) # (N_batch', d_features,) + diff_squares.append( + torch.mean(batch_diffs**2, dim=(1, 2)) + ) # (N_batch', d_features,) diff_mean = torch.mean(torch.cat(diff_means, dim=0), dim=0) # (d_features) diff_second_moment = torch.mean(torch.cat(diff_squares, dim=0), dim=0) diff --git a/create_static_features.py b/create_static_features.py index f8cf5a2e..cbab9259 100644 --- a/create_static_features.py +++ b/create_static_features.py @@ -1,56 +1,93 @@ +# Standard library from argparse import ArgumentParser +# Third-party import numpy as np import xarray as xr +# First-party from neural_lam import constants def main(): - parser = ArgumentParser(description='Static features arguments') - parser.add_argument('--xdim', type=str, default="x_1", - help='Name of the x-dimension in the dataset (default: x_1)') - parser.add_argument('--ydim', type=str, default="y_1", - help='Name of the x-dimension in the dataset (default: y_1)') - parser.add_argument('--zdim', type=str, default="z_1", - help='Name of the x-dimension in the dataset (default: z_1)') + """Create the static features for the neural network.""" + parser = ArgumentParser(description="Static features arguments") parser.add_argument( - '--field_names', nargs="+", default=["hsurf", "FI", "P0FL"], - help='Names of the fields to extract from the .nc file (default: ["hsurf", "FI", "P0FL"])'), + "--xdim", + type=str, + default="x_1", + help="Name of the x-dimension in the dataset (default: x_1)", + ) parser.add_argument( - '--boundaries', type=int, default=30, - help='Number of grid-cells closest to each boundary to mask (default: 30)') + "--ydim", + type=str, + default="y_1", + help="Name of the x-dimension in the dataset (default: y_1)", + ) parser.add_argument( - '--outdir', type=str, default="data/cosmo/static/", - help='Output directory for the static features (default: data/cosmo/static/)') + "--zdim", + type=str, + default="z_1", + help="Name of the x-dimension in the dataset (default: z_1)", + ) + parser.add_argument( + "--field_names", + nargs="+", + default=["hsurf", "FI", "P0FL"], + help=( + "Names of the fields to extract from the .nc file " + '(default: ["hsurf", "FI", "P0FL"])' + ), + ) + parser.add_argument( + "--boundaries", + type=int, + default=30, + help=( + "Number of grid-cells closest to each boundary to mask " + "(default: 30)" + ), + ) + parser.add_argument( + "--outdir", + type=str, + default="data/cosmo/static/", + help=( + "Output directory for the static features " + "(default: data/cosmo/static/)" + ), + ) args = parser.parse_args() - # Open the .nc file - ds = xr.open_zarr(constants.example_file).isel(time=0) + ds = xr.open_zarr(constants.EXAMPLE_FILE).isel(time=0) np_fields = [] for var_name in args.field_names: # scale the variable to [0, 1] - ds[var_name] = (ds[var_name] - ds[var_name].min() - ) / (ds[var_name].max() - ds[var_name].min()) + ds[var_name] = (ds[var_name] - ds[var_name].min()) / ( + ds[var_name].max() - ds[var_name].min() + ) if args.zdim not in ds[var_name].dims: field_2d = ds[var_name].transpose(args.xdim, args.ydim).values # add a dummy dimension np_fields.append(np.expand_dims(field_2d, axis=-1)) else: - np_fields.append(ds[var_name].sel({args.zdim: constants.vertical_levels}).transpose( - args.xdim, args.ydim, args.zdim).values) - - np_fields = np.concatenate(np_fields, axis=-1) + np_fields.append( + ds[var_name] + .sel({args.zdim: constants.VERTICAL_LEVELS}) + .transpose(args.xdim, args.ydim, args.zdim) + .values + ) + np_fields = np.concatenate(np_fields, axis=-1) # (N_x, N_y, N_fields) # Save the numpy array to a .npy file - np.save(args.outdir + 'reference_geopotential_pressure.npy', np_fields) + np.save(args.outdir + "reference_geopotential_pressure.npy", np_fields) # Get the dimensions of the dataset - dims = ds.dims - x_dim, y_dim = ds.dims[args.xdim], ds.dims[args.ydim] + dims = ds.sizes + x_dim, y_dim = ds.sizes[args.xdim], ds.sizes[args.ydim] # Create a 2D meshgrid for x and y indices x_grid, y_grid = np.indices((x_dim, y_dim)) @@ -58,19 +95,19 @@ def main(): # Stack the 2D arrays into a 3D array with x and y as the first dimension grid_xy = np.stack((y_grid, x_grid)) - np.save(args.outdir + 'nwp_xy.npy', grid_xy) + np.save(args.outdir + "nwp_xy.npy", grid_xy) # (2, N_x, N_y) # Create a mask with the same dimensions, initially set to False mask = np.full((dims[args.xdim], dims[args.ydim]), False) # Set the args.boundaries grid-cells closest to each boundary to True - mask[:args.boundaries, :] = True # top boundary - mask[-args.boundaries:, :] = True # bottom boundary - mask[:, :args.boundaries] = True # left boundary - mask[:, -args.boundaries:] = True # right boundary + mask[: args.boundaries, :] = True # top boundary + mask[-args.boundaries :, :] = True # bottom boundary + mask[:, : args.boundaries] = True # left boundary + mask[:, -args.boundaries :] = True # right boundary # Save the numpy array to a .npy file - np.save(args.outdir + 'border_mask', mask) + np.save(args.outdir + "border_mask", mask) # (N_x, N_y) if __name__ == "__main__": diff --git a/create_zarr_archive.py b/create_zarr_archive.py index cd4d8448..2a81f5b3 100644 --- a/create_zarr_archive.py +++ b/create_zarr_archive.py @@ -5,16 +5,18 @@ import re import shutil -import numcodecs - # Third-party +import numcodecs import xarray as xr from tqdm import tqdm +# First-party from neural_lam import constants -def append_or_create_zarr(data_out: xr.Dataset, config: dict, zarr_name: str) -> None: +def append_or_create_zarr( + data_out: xr.Dataset, config: dict, zarr_name: str +) -> None: """Append data to an existing Zarr archive or create a new one.""" if config["test_year"] in data_out.time.dt.year.values: @@ -41,7 +43,7 @@ def load_data(config: dict) -> None: """Load weather data from NetCDF files and store it in a Zarr archive.""" file_paths = [] - for root, dirs, files in os.walk(data_config["data_path"]): + for root, _, files in os.walk(config["data_path"]): for file in files: full_path = os.path.join(root, file) file_paths.append(full_path) @@ -49,75 +51,75 @@ def load_data(config: dict) -> None: # Group file paths into chunks file_groups = [ - file_paths[i: i + config["chunk_size"]] - for i in range(0, len(file_paths), - config["chunk_size"])] + file_paths[i : i + config["chunk_size"]] + for i in range(0, len(file_paths), config["chunk_size"]) + ] for group in tqdm(file_groups, desc="Processing file groups"): - # Create a new Zarr archive for each group - # Extract the date from the first file in the group - date = os.path.basename(group[0]).split('_')[0][3:] + # Create a new Zarr archive for each group Extract the date from the + # first file in the group + date = os.path.basename(group[0]).split("_")[0][3:] zarr_name = f"data_{date}.zarr" if not os.path.exists( - os.path.join(config["zarr_path"], - "train", zarr_name)) and not os.path.exists( - os.path.join(config["zarr_path"], - "test", zarr_name)): + os.path.join(config["zarr_path"], "train", zarr_name) + ) and not os.path.exists( + os.path.join(config["zarr_path"], "test", zarr_name) + ): for full_path in group: process_file(full_path, config, zarr_name) def process_file(full_path, config, zarr_name): + """Process a single NetCDF file and store it in a Zarr archive.""" try: # if zarr_name directory exists, skip match = config["filename_pattern"].match(full_path) if not match: return None - data: xr.Dataset = xr.open_dataset(full_path, engine="netcdf4", chunks={ - "time": 1, - "x_1": -1, - "y_1": -1, - "z_1": -1, - "zbound": -1, - }, autoclose=True).drop_vars("grid_mapping_1") + data: xr.Dataset = xr.open_dataset( + full_path, + engine="netcdf4", + chunks={ + "time": 1, + "x_1": -1, + "y_1": -1, + "z_1": -1, + "zbound": -1, + }, + autoclose=True, + ).drop_vars("grid_mapping_1") for var in data.variables: data[var].encoding = {"compressor": config["compressor"]} - data.time.encoding = {'dtype': 'float64'} + data.time.encoding = {"dtype": "float64"} append_or_create_zarr(data, config, zarr_name) # Display the progress print(f"Processed: {full_path}") except (FileNotFoundError, OSError) as e: print(f"Error: {e}") + return None def combine_zarr_archives(config) -> None: - """Combine the last Zarr archive from the train folder with the first from the test - folder.""" + """Combine the last Zarr archive from the train folder with the first from + the test folder.""" # Get the last Zarr archive from the train folder train_archives = sorted( - glob.glob( - os.path.join( - data_config["zarr_path"], - "train", - '*.zarr'))) + glob.glob(os.path.join(config["zarr_path"], "train", "*.zarr")) + ) # Get the first Zarr archive from the test folder test_archives = sorted( - glob.glob( - os.path.join( - data_config["zarr_path"], - "test", - '*.zarr'))) + glob.glob(os.path.join(config["zarr_path"], "test", "*.zarr")) + ) first_test_archive = xr.open_zarr(test_archives[0], consolidated=True) - val_archives_path = os.path.join(data_config["zarr_path"], "val") + val_archives_path = os.path.join(config["zarr_path"], "val") for t in range(first_test_archive.time.size): - first_test_archive.isel(time=slice(t, t + 1)).to_zarr(train_archives[-1], - mode="a", - append_dim="time", - consolidated=True) + first_test_archive.isel(time=slice(t, t + 1)).to_zarr( + train_archives[-1], mode="a", append_dim="time", consolidated=True + ) shutil.rmtree(test_archives[0]) shutil.rmtree(test_archives[-1]) @@ -128,35 +130,45 @@ def combine_zarr_archives(config) -> None: if __name__ == "__main__": - - parser = argparse.ArgumentParser(description='Create a zarr archive.') + parser = argparse.ArgumentParser(description="Create a zarr archive.") + parser.add_argument( + "--data_path", + type=str, + required=True, + help="Path to the raw data", + default="/scratch/mch/sadamov/ml_v1/", + ) + parser.add_argument( + "--test_year", type=int, required=True, help="Test year", default=2020 + ) parser.add_argument( - '--data_path', + "--filename_regex", type=str, required=True, - help='Path to the raw data', - default="/scratch/mch/sadamov/ml_v1/") - parser.add_argument('--test_year', type=int, required=True, - help='Test year', default=2020) - parser.add_argument('--filename_regex', type=str, required=True, - help='Filename regex', default="(.*)_extr.nc") + help="Filename regex", + default="(.*)_extr.nc", + ) args = parser.parse_args() data_config = { "data_path": args.data_path, "filename_regex": args.filename_regex, - "zarr_path": "/users/sadamov/pyprojects/neural-cosmo/data/cosmo/samples", + "zarr_path": ( + "/users/sadamov/pyprojects/" "neural-cosmo/data/cosmo/samples" + ), "compressor": numcodecs.Blosc( - cname='lz4', - clevel=7, - shuffle=numcodecs.Blosc.SHUFFLE), - "chunk_size": constants.chunk_size, + cname="lz4", clevel=7, shuffle=numcodecs.Blosc.SHUFFLE + ), + "chunk_size": constants.CHUNK_SIZE, "test_year": args.test_year, } data_config.update( - {"folders": os.listdir(data_config["data_path"]), - "filename_pattern": re.compile(data_config["filename_regex"])}) + { + "folders": os.listdir(data_config["data_path"]), + "filename_pattern": re.compile(data_config["filename_regex"]), + } + ) load_data(data_config) combine_zarr_archives(data_config) diff --git a/environment.yml b/environment.yml index 2ca64f80..912998d0 100644 --- a/environment.yml +++ b/environment.yml @@ -30,3 +30,9 @@ dependencies: - zarr - pip: - tueplots + - codespell>=2.0.0 + - black>=21.9b0 + - isort>=5.9.3 + - flake8>=4.0.1 + - pylint>=3.0.3 + - pre-commit>=2.15.0 diff --git a/helper.py b/helper.py index b1d9c150..166212b0 100644 --- a/helper.py +++ b/helper.py @@ -1,44 +1,51 @@ +# Standard library import os +# Third-party import xarray as xr -path = "data/cosmo/samples/test/" +PATH = "data/cosmo/samples/test/" # Initialize a dictionary to store the top-1 precipitation event for each file precip_events = {} -for file in os.listdir(path): +for file in os.listdir(PATH): print(file) - ds = xr.open_zarr(os.path.join(path, file)) + ds = xr.open_zarr(os.path.join(PATH, file)) - ds_rechunked = ds.chunk({'time': -1}) - mean_tot_prec = ds_rechunked['TOT_PREC'].mean(dim=['y_1', 'x_1']).compute() + ds_rechunked = ds.chunk({"time": -1}) + mean_tot_prec = ds_rechunked["TOT_PREC"].mean(dim=["y_1", "x_1"]).compute() # Find the maximum precipitation value and its corresponding time max_precip_value = mean_tot_prec.max().item() max_precip_time = mean_tot_prec.where( - mean_tot_prec == max_precip_value, - drop=True).time.values[0] - max_precip_time_str = str(max_precip_time) # Convert to string for dictionary key + mean_tot_prec == max_precip_value, drop=True + ).time.values[0] + MAX_PRECIP_TIME_STR = str( + max_precip_time + ) # Convert to string for dictionary key # Store the top-1 precipitation event in the dictionary precip_events[file] = { - 'max_time': max_precip_time_str, - 'max_value': max_precip_value, + "max_time": MAX_PRECIP_TIME_STR, + "max_value": max_precip_value, } # Find the file with the maximum and minimum precipitation values -max_precip_file = max(precip_events, key=lambda x: precip_events[x]['max_value']) +max_precip_file = max( + precip_events, key=lambda x: precip_events[x]["max_value"] +) max_precip_event = precip_events[max_precip_file] # Sort the precipitation events by maximum value in descending order sorted_precip_events = sorted( - precip_events.items(), - key=lambda x: x[1]['max_value'], - reverse=True) + precip_events.items(), key=lambda x: x[1]["max_value"], reverse=True +) # Print the top ten precipitation events print("Top ten maximum precipitation events:") for i, (file, event) in enumerate(sorted_precip_events[:10]): print( - f"{i+1}: {event['max_time']} with a value of {event['max_value']} in file {file}") + f"{i + 1}: {event['max_time']} with a value of " + f"{event['max_value']} in file {file}" + ) diff --git a/neural_lam/constants.py b/neural_lam/constants.py index bb2f98e2..b6f57c50 100644 --- a/neural_lam/constants.py +++ b/neural_lam/constants.py @@ -1,101 +1,119 @@ +# Third-party import numpy as np from cartopy import crs as ccrs -wandb_project = "neural-lam" - -# Full names -param_names = [ - 'Temperature', - 'Zonal wind component', - 'Meridional wind component', - 'Relative humidity', - 'Pressure at Mean Sea Level', - 'Pressure Perturbation', - 'Surface Pressure', - 'Total Precipitation', - 'Total Water Vapor content', - '2-meter Temperature', - '10-meter Zonal wind speed', - '10-meter Meridional wind speed', +WANDB_PROJECT = "neural-lam" + +SECONDS_IN_YEAR = ( + 365 * 24 * 60 * 60 +) # Assuming no leap years in dataset (2024 is next) + +# Log prediction error for these lead times +VAL_STEP_LOG_ERRORS = np.array([1, 2, 3, 5, 10, 15, 19]) + +# Log these metrics to wandb as scalar values for +# specific variables and lead times +# List of metrics to watch, including any prefix (e.g. val_rmse) +METRICS_WATCH = [] +# Dict with variables and lead times to log watched metrics for +# Format is a dictionary that maps from a variable index to +# a list of lead time steps +VAR_LEADS_METRICS_WATCH = { + 6: [2, 19], # t_2 + 14: [2, 19], # wvint_0 + 15: [2, 19], # z_1000 +} + +# Variable names +PARAM_NAMES = [ + "Temperature", + "Zonal wind component", + "Meridional wind component", + "Relative humidity", + "Pressure at Mean Sea Level", + "Pressure Perturbation", + "Surface Pressure", + "Total Precipitation", + "Total Water Vapor content", + "2-meter Temperature", + "10-meter Zonal wind speed", + "10-meter Meridional wind speed", ] # Short names -param_names_short = [ - 'T', - 'U', - 'V', - 'RELHUM', - 'PMSL', - 'PP', - 'PS', - 'TOT_PREC', - 'TQV', - 'T_2M', - 'U_10M', - 'V_10M', +PARAM_NAMES_SHORT = [ + "T", + "U", + "V", + "RELHUM", + "PMSL", + "PP", + "PS", + "TOT_PREC", + "TQV", + "T_2M", + "U_10M", + "V_10M", ] # Units -param_units = [ - 'K', - 'm/s', - 'm/s', - 'Perc.', - 'Pa', - 'hPa', - 'Pa', - '$kg/m^2$', - '$kg/m^2$', - 'K', - 'm/s', - 'm/s', +PARAM_UNITS = [ + "K", + "m/s", + "m/s", + "Perc.", + "Pa", + "hPa", + "Pa", + "$kg/m^2$", + "$kg/m^2$", + "K", + "m/s", + "m/s", ] # Parameter weights -param_weights = { - 'T': 1, - 'U': 1, - 'V': 1, - 'RELHUM': 1, - 'PMSL': 1, - 'PP': 1, - 'PS': 1, - 'TOT_PREC': 1, - 'TQV': 1, - 'T_2M': 1, - 'U_10M': 1, - 'V_10M': 1, +PARAM_WEIGHTS = { + "T": 1, + "U": 1, + "V": 1, + "RELHUM": 1, + "PMSL": 1, + "PP": 1, + "PS": 1, + "TOT_PREC": 1, + "TQV": 1, + "T_2M": 1, + "U_10M": 1, + "V_10M": 1, } # Vertical levels -vertical_levels = [ - 1, 5, 13, 22, 38, 41, 60 -] +VERTICAL_LEVELS = [1, 5, 13, 22, 38, 41, 60] -param_constraints = { - 'RELHUM': (0, 100), - 'TQV': (0, None), - 'TOT_PREC': (0, None), +PARAM_CONSTRAINTS = { + "RELHUM": (0, 100), + "TQV": (0, None), + "TOT_PREC": (0, None), } -is_3d = { - 'T': 1, - 'U': 1, - 'V': 1, - 'RELHUM': 1, - 'PMSL': 0, - 'PP': 1, - 'PS': 0, - 'TOT_PREC': 0, - 'TQV': 0, - 'T_2M': 0, # TODO: these 2d field diagnostic variables could be removed from input channels, and derived during inference - 'U_10M': 0, - 'V_10M': 0, +IS_3D = { + "T": 1, + "U": 1, + "V": 1, + "RELHUM": 1, + "PMSL": 0, + "PP": 1, + "PS": 0, + "TOT_PREC": 0, + "TQV": 0, + "T_2M": 0, + "U_10M": 0, + "V_10M": 0, } # Vertical level weights -# TODO: exponential function of height -level_weights = { +LEVEL_WEIGHTS = { 1: 1, 5: 1, 13: 1, @@ -106,32 +124,35 @@ } # Projection and grid -grid_shape = (390, 582) # (y, x) +GRID_SHAPE = (390, 582) # (y, x) # Time step prediction during training / prediction (eval) -train_horizon = 3 # hours (t-1 + t -> t+1) -eval_horizon = 25 # hours (autoregressive) +TRAIN_HORIZON = 3 # hours (t-1 + t -> t+1) +EVAL_HORIZON = 25 # hours (autoregressive) # Properties of the Graph / Mesh -graph_num_children = 3 +GRAPH_NUM_CHILDREN = 3 # Log prediction error for these time steps forward -val_step_log_errors = np.arange(1, eval_horizon - 1) -metrics_initialized = False +VAL_STEP_LOG_ERRORS = np.arange(1, EVAL_HORIZON - 1) +METRICS_INITIALIZED = False # Plotting -fig_size = (15, 10) -example_file = "data/cosmo/samples/train/data_2015112800.zarr" -chunk_size = 100 -eval_datetime = "2020100215" -eval_plot_vars = ["TQV"] -store_example_data = False -cosmo_proj = ccrs.PlateCarree() -selected_proj = cosmo_proj -pollon = -170.0 -pollat = 43.0 -smooth_boundaries = False +FIG_SIZE = (15, 10) +EXAMPLE_FILE = "data/cosmo/samples/train/data_2015112800.zarr" +CHUNK_SIZE = 100 +EVAL_DATETIME = "2020100215" +EVAL_PLOT_VARS = ["TQV"] +STORE_EXAMPLE_DATA = False +COSMO_PROJ = ccrs.PlateCarree() +SELECTED_PROJ = COSMO_PROJ +POLLON = -170.0 +POLLAT = 43.0 +SMOOTH_BOUNDARIES = False # Some constants useful for sub-classes -batch_static_feature_dim = 0 -grid_forcing_dim = 0 +BATCH_STATIC_FEATURE_DIM = 0 +GRID_FORCING_DIM = 0 +GRID_STATE_DIM = sum( + len(VERTICAL_LEVELS) if IS_3D[param] else 1 for param in PARAM_NAMES_SHORT +) diff --git a/neural_lam/interaction_net.py b/neural_lam/interaction_net.py index bbb96bac..663f27e4 100644 --- a/neural_lam/interaction_net.py +++ b/neural_lam/interaction_net.py @@ -1,29 +1,48 @@ +# Third-party import torch import torch_geometric as pyg from torch import nn +# First-party from neural_lam import utils class InteractionNet(pyg.nn.MessagePassing): """ - Implementation of a generic Interaction Network, from Battaglia et al. (2016) + Implementation of a generic Interaction Network, + from Battaglia et al. (2016) """ - def __init__(self, edge_index, input_dim, update_edges=True, hidden_layers=1, - hidden_dim=None, edge_chunk_sizes=None, aggr_chunk_sizes=None, aggr="sum"): + # pylint: disable=arguments-differ + # Disable to override args/kwargs from superclass + + def __init__( + self, + edge_index, + input_dim, + update_edges=True, + hidden_layers=1, + hidden_dim=None, + edge_chunk_sizes=None, + aggr_chunk_sizes=None, + aggr="sum", + ): """ Create a new InteractionNet edge_index: (2,M), Edges in pyg format - input_dim: Dimensionality of input representations, for both nodes and edges - update_edges: If new edge representations should be computed and returned + input_dim: Dimensionality of input representations, + for both nodes and edges + update_edges: If new edge representations should be computed + and returned hidden_layers: Number of hidden layers in MLPs - hidden_dim: Dimensionality of hidden layers, if None then same as input_dim - edge_chunk_sizes: List of chunks sizes to split edge representation into and - use separate MLPs for (None = no chunking, same MLP) - aggr_chunk_sizes: List of chunks sizes to split aggregated node representation + hidden_dim: Dimensionality of hidden layers, if None then same + as input_dim + edge_chunk_sizes: List of chunks sizes to split edge representation into and use separate MLPs for (None = no chunking, same MLP) + aggr_chunk_sizes: List of chunks sizes to split aggregated node + representation into and use separate MLPs for + (None = no chunking, same MLP) aggr: Message aggregation method (sum/mean) """ assert aggr in ("sum", "mean"), f"Unknown aggregation method: {aggr}" @@ -37,7 +56,9 @@ def __init__(self, edge_index, input_dim, update_edges=True, hidden_layers=1, edge_index = edge_index - edge_index.min(dim=1, keepdim=True)[0] # Store number of receiver nodes according to edge_index self.num_rec = edge_index[1].max() + 1 - edge_index[0] = edge_index[0] + self.num_rec # Make sender indices after rec + edge_index[0] = ( + edge_index[0] + self.num_rec + ) # Make sender indices after rec self.register_buffer("edge_index", edge_index, persistent=False) # Create MLPs @@ -47,21 +68,25 @@ def __init__(self, edge_index, input_dim, update_edges=True, hidden_layers=1, if edge_chunk_sizes is None: self.edge_mlp = utils.make_mlp(edge_mlp_recipe) else: - self.edge_mlp = SplitMLPs([utils.make_mlp(edge_mlp_recipe) for _ in - edge_chunk_sizes], edge_chunk_sizes) + self.edge_mlp = SplitMLPs( + [utils.make_mlp(edge_mlp_recipe) for _ in edge_chunk_sizes], + edge_chunk_sizes, + ) if aggr_chunk_sizes is None: self.aggr_mlp = utils.make_mlp(aggr_mlp_recipe) else: - self.aggr_mlp = SplitMLPs([utils.make_mlp(aggr_mlp_recipe) for _ in - aggr_chunk_sizes], aggr_chunk_sizes) + self.aggr_mlp = SplitMLPs( + [utils.make_mlp(aggr_mlp_recipe) for _ in aggr_chunk_sizes], + aggr_chunk_sizes, + ) self.update_edges = update_edges def forward(self, send_rep, rec_rep, edge_rep): """ - Apply interaction network to update the representations of receiver nodes, - and optionally the edge representations. + Apply interaction network to update the representations of receiver + nodes, and optionally the edge representations. send_rep: (N_send, d_h), vector representations of sender nodes rec_rep: (N_rec, d_h), vector representations of receiver nodes @@ -69,14 +94,15 @@ def forward(self, send_rep, rec_rep, edge_rep): Returns: rec_rep: (N_rec, d_h), updated vector representations of receiver nodes - (optionally) edge_rep: (M, d_h), updated vector representations of edges + (optionally) edge_rep: (M, d_h), updated vector representations + of edges """ - # Always concatenate to [rec_nodes, send_nodes] for propagation, but only - # aggregate to rec_nodes - # TODO: edge_index to device? - node_reps = torch.cat((rec_rep, send_rep), dim=1) - edge_rep_aggr, edge_diff = self.propagate(self.edge_index, x=node_reps, - edge_attr=edge_rep) + # Always concatenate to [rec_nodes, send_nodes] for propagation, + # but only aggregate to rec_nodes + node_reps = torch.cat((rec_rep, send_rep), dim=-2) + edge_rep_aggr, edge_diff = self.propagate( + self.edge_index, x=node_reps, edge_attr=edge_rep + ) rec_diff = self.aggr_mlp(torch.cat((rec_rep, edge_rep_aggr), dim=-1)) # Residual connections @@ -94,14 +120,15 @@ def message(self, x_j, x_i, edge_attr): """ return self.edge_mlp(torch.cat((edge_attr, x_j, x_i), dim=-1)) - def aggregate(self, messages, index, ptr, dim_size): + # pylint: disable-next=signature-differs + def aggregate(self, inputs, index, ptr, dim_size): """ Overridden aggregation function to: * return both aggregated and original messages, * only aggregate to number of receiver nodes. """ - aggr = super().aggregate(messages, index, ptr, self.num_rec) - return aggr, messages + aggr = super().aggregate(inputs, index, ptr, self.num_rec) + return aggr, inputs class SplitMLPs(nn.Module): @@ -113,8 +140,9 @@ class SplitMLPs(nn.Module): def __init__(self, mlps, chunk_sizes): super().__init__() - assert len(mlps) == len(chunk_sizes), ( - "Number of MLPs must match the number of chunks") + assert len(mlps) == len( + chunk_sizes + ), "Number of MLPs must match the number of chunks" self.mlps = nn.ModuleList(mlps) self.chunk_sizes = chunk_sizes @@ -126,9 +154,10 @@ def forward(self, x): x: (..., N, d), where N = sum(chunk_sizes) Returns: - joined_output: (..., N, d), concatenated results from the different MLPs + joined_output: (..., N, d), concatenated results from the MLPs """ chunks = torch.split(x, self.chunk_sizes, dim=-2) - chunk_outputs = [mlp(chunk_input) - for mlp, chunk_input in zip(self.mlps, chunks)] + chunk_outputs = [ + mlp(chunk_input) for mlp, chunk_input in zip(self.mlps, chunks) + ] return torch.cat(chunk_outputs, dim=-2) diff --git a/neural_lam/metrics.py b/neural_lam/metrics.py new file mode 100644 index 00000000..7db2cca6 --- /dev/null +++ b/neural_lam/metrics.py @@ -0,0 +1,237 @@ +# Third-party +import torch + + +def get_metric(metric_name): + """ + Get a defined metric with given name + + metric_name: str, name of the metric + + Returns: + metric: function implementing the metric + """ + metric_name_lower = metric_name.lower() + assert ( + metric_name_lower in DEFINED_METRICS + ), f"Unknown metric: {metric_name}" + return DEFINED_METRICS[metric_name_lower] + + +def mask_and_reduce_metric(metric_entry_vals, mask, average_grid, sum_vars): + """ + Masks and (optionally) reduces entry-wise metric values + + (...,) is any number of batch dimensions, potentially different + but broadcastable + metric_entry_vals: (..., N, d_state), prediction + mask: (N,), boolean mask describing which grid nodes to use in metric + average_grid: boolean, if grid dimension -2 should be reduced (mean over N) + sum_vars: boolean, if variable dimension -1 should be reduced (sum + over d_state) + + Returns: + metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), + depending on reduction arguments. + """ + # Only keep grid nodes in mask + if mask is not None: + metric_entry_vals = metric_entry_vals[ + ..., mask, : + ] # (..., N', d_state) + + # Optionally reduce last two dimensions + if average_grid: # Reduce grid first + metric_entry_vals = torch.mean( + metric_entry_vals, dim=-2 + ) # (..., d_state) + if sum_vars: # Reduce vars second + metric_entry_vals = torch.sum( + metric_entry_vals, dim=-1 + ) # (..., N) or (...,) + + return metric_entry_vals + + +def wmse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): + """ + Weighted Mean Squared Error + + (...,) is any number of batch dimensions, potentially different + but broadcastable + pred: (..., N, d_state), prediction + target: (..., N, d_state), target + pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. + mask: (N,), boolean mask describing which grid nodes to use in metric + average_grid: boolean, if grid dimension -2 should be reduced (mean over N) + sum_vars: boolean, if variable dimension -1 should be reduced (sum + over d_state) + + Returns: + metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), + depending on reduction arguments. + """ + entry_mse = torch.nn.functional.mse_loss( + pred, target, reduction="none" + ) # (..., N, d_state) + entry_mse_weighted = entry_mse / (pred_std**2) # (..., N, d_state) + + return mask_and_reduce_metric( + entry_mse_weighted, + mask=mask, + average_grid=average_grid, + sum_vars=sum_vars, + ) + + +def mse(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): + """ + (Unweighted) Mean Squared Error + + (...,) is any number of batch dimensions, potentially different + but broadcastable + pred: (..., N, d_state), prediction + target: (..., N, d_state), target + pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. + mask: (N,), boolean mask describing which grid nodes to use in metric + average_grid: boolean, if grid dimension -2 should be reduced (mean over N) + sum_vars: boolean, if variable dimension -1 should be reduced (sum + over d_state) + + Returns: + metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), + depending on reduction arguments. + """ + # Replace pred_std with constant ones + return wmse( + pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars + ) + + +def wmae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): + """ + Weighted Mean Absolute Error + + (...,) is any number of batch dimensions, potentially different + but broadcastable + pred: (..., N, d_state), prediction + target: (..., N, d_state), target + pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. + mask: (N,), boolean mask describing which grid nodes to use in metric + average_grid: boolean, if grid dimension -2 should be reduced (mean over N) + sum_vars: boolean, if variable dimension -1 should be reduced (sum + over d_state) + + Returns: + metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), + depending on reduction arguments. + """ + entry_mae = torch.nn.functional.l1_loss( + pred, target, reduction="none" + ) # (..., N, d_state) + entry_mae_weighted = entry_mae / pred_std # (..., N, d_state) + + return mask_and_reduce_metric( + entry_mae_weighted, + mask=mask, + average_grid=average_grid, + sum_vars=sum_vars, + ) + + +def mae(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): + """ + (Unweighted) Mean Absolute Error + + (...,) is any number of batch dimensions, potentially different + but broadcastable + pred: (..., N, d_state), prediction + target: (..., N, d_state), target + pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. + mask: (N,), boolean mask describing which grid nodes to use in metric + average_grid: boolean, if grid dimension -2 should be reduced (mean over N) + sum_vars: boolean, if variable dimension -1 should be reduced (sum + over d_state) + + Returns: + metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), + depending on reduction arguments. + """ + # Replace pred_std with constant ones + return wmae( + pred, target, torch.ones_like(pred_std), mask, average_grid, sum_vars + ) + + +def nll(pred, target, pred_std, mask=None, average_grid=True, sum_vars=True): + """ + Negative Log Likelihood loss, for isotropic Gaussian likelihood + + (...,) is any number of batch dimensions, potentially different + but broadcastable + pred: (..., N, d_state), prediction + target: (..., N, d_state), target + pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. + mask: (N,), boolean mask describing which grid nodes to use in metric + average_grid: boolean, if grid dimension -2 should be reduced (mean over N) + sum_vars: boolean, if variable dimension -1 should be reduced (sum + over d_state) + + Returns: + metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), + depending on reduction arguments. + """ + # Broadcast pred_std if shaped (d_state,), done internally in Normal class + dist = torch.distributions.Normal(pred, pred_std) # (..., N, d_state) + entry_nll = -dist.log_prob(target) # (..., N, d_state) + + return mask_and_reduce_metric( + entry_nll, mask=mask, average_grid=average_grid, sum_vars=sum_vars + ) + + +def crps_gauss( + pred, target, pred_std, mask=None, average_grid=True, sum_vars=True +): + """ + (Negative) Continuous Ranked Probability Score (CRPS) + Closed-form expression based on Gaussian predictive distribution + + (...,) is any number of batch dimensions, potentially different + but broadcastable + pred: (..., N, d_state), prediction + target: (..., N, d_state), target + pred_std: (..., N, d_state) or (d_state,), predicted std.-dev. + mask: (N,), boolean mask describing which grid nodes to use in metric + average_grid: boolean, if grid dimension -2 should be reduced (mean over N) + sum_vars: boolean, if variable dimension -1 should be reduced (sum + over d_state) + + Returns: + metric_val: One of (...,), (..., d_state), (..., N), (..., N, d_state), + depending on reduction arguments. + """ + std_normal = torch.distributions.Normal( + torch.zeros((), device=pred.device), torch.ones((), device=pred.device) + ) + target_standard = (target - pred) / pred_std # (..., N, d_state) + + entry_crps = -pred_std * ( + torch.pi ** (-0.5) + - 2 * torch.exp(std_normal.log_prob(target_standard)) + - target_standard * (2 * std_normal.cdf(target_standard) - 1) + ) # (..., N, d_state) + + return mask_and_reduce_metric( + entry_crps, mask=mask, average_grid=average_grid, sum_vars=sum_vars + ) + + +DEFINED_METRICS = { + "mse": mse, + "mae": mae, + "wmse": wmse, + "wmae": wmae, + "nll": nll, + "crps_gauss": crps_gauss, +} diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index 2f2eeb72..3b463f61 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -1,24 +1,32 @@ +# pylint: disable=wrong-import-order +# Standard library import glob import os from datetime import datetime, timedelta +# Third-party import imageio import matplotlib.pyplot as plt import numpy as np import pytorch_lightning as pl import torch +import wandb from torch import nn -import wandb -from neural_lam import constants, utils, vis +# First-party +from neural_lam import constants, metrics, utils, vis +# pylint: disable=too-many-public-methods class ARModel(pl.LightningModule): """ Generic auto-regressive weather model. Abstract class that can be extended. """ + # pylint: disable=arguments-differ + # Disable to override args/kwargs from superclass + def __init__(self, args): super().__init__() @@ -26,50 +34,78 @@ def __init__(self, args): self.lr = args.lr # Log prediction error for these time steps forward - self.val_step_log_errors = constants.val_step_log_errors - self.metrics_initialized = constants.metrics_initialized + self.val_step_log_errors = constants.VAL_STEP_LOG_ERRORS + self.metrics_initialized = constants.METRICS_INITIALIZED # Some constants useful for sub-classes - self.batch_static_feature_dim = constants.batch_static_feature_dim - self.grid_forcing_dim = constants.grid_forcing_dim - count_3d_fields = sum(value == 1 for value in constants.is_3d.values()) - count_2d_fields = sum(value != 1 for value in constants.is_3d.values()) - self.grid_state_dim = len( - constants.vertical_levels) * count_3d_fields + count_2d_fields + self.batch_static_feature_dim = constants.BATCH_STATIC_FEATURE_DIM + self.grid_forcing_dim = constants.GRID_FORCING_DIM + count_3d_fields = sum(value == 1 for value in constants.IS_3D.values()) + count_2d_fields = sum(value != 1 for value in constants.IS_3D.values()) + self.grid_state_dim = ( + len(constants.VERTICAL_LEVELS) * count_3d_fields + count_2d_fields + ) # Load static features for grid/data static_data_dict = utils.load_static_data(args.dataset) for static_data_name, static_data_tensor in static_data_dict.items(): - self.register_buffer(static_data_name, static_data_tensor, persistent=False) - - # MSE loss, need to do reduction ourselves to get proper weighting - self.loss_name = args.loss - if args.loss == "mse": - self.loss = nn.MSELoss(reduction="none") - - inv_var = self.step_diff_std**-2. - state_weight = self.param_weights * inv_var # (d_f,) - elif args.loss == "mae": - self.loss = nn.L1Loss(reduction="none") - - # Weight states with inverse std instead in this case - state_weight = self.param_weights / self.step_diff_std # (d_f,) + self.register_buffer( + static_data_name, static_data_tensor, persistent=False + ) + + # Double grid output dim. to also output std.-dev. + self.output_std = bool(args.output_std) + if self.output_std: + self.grid_output_dim = ( + 2 * constants.GRID_STATE_DIM + ) # Pred. dim. in grid cell else: - assert False, f"Unknown loss function: {args.loss}" - self.register_buffer("state_weight", state_weight, persistent=False) + self.grid_output_dim = ( + constants.GRID_STATE_DIM + ) # Pred. dim. in grid cell + + # Store constant per-variable std.-dev. weighting + # Note that this is the inverse of the multiplicative weighting + # in wMSE/wMAE + self.register_buffer( + "per_var_std", + self.step_diff_std / torch.sqrt(self.param_weights), + persistent=False, + ) + + # grid_dim from data + static + batch_static + ( + self.num_grid_nodes, + grid_static_dim, + ) = self.grid_static_features.shape # 63784 = 268x238 + self.grid_dim = ( + 2 * constants.GRID_STATE_DIM + + grid_static_dim + + constants.GRID_FORCING_DIM + + constants.BATCH_STATIC_FEATURE_DIM + ) + + # Instantiate loss function + self.loss = metrics.get_metric(args.loss) # Pre-compute interior mask for use in loss function - self.interior_mask = 1. - self.border_mask # (N_grid, 1), 1 for non-border - # Number of grid nodes to predict - self.N_interior = torch.sum(self.interior_mask) + self.register_buffer( + "interior_mask", 1.0 - self.border_mask, persistent=False + ) # (num_grid_nodes, 1), 1 for non-border self.step_length = args.step_length # Number of hours per pred. step - self.val_errs = [] - self.test_maes = [] - self.test_mses = [] + self.val_metrics = { + "mse": [], + } + self.test_metrics = { + "mse": [], + "mae": [], + } + if self.output_std: + self.test_metrics["output_std"] = [] # Treat as metric # For making restoring of optimizer state optional - self.resume_opt_sched = args.resume_opt_sched + self.restore_opt = args.restore_opt # For example plotting self.n_example_pred = args.n_example_pred @@ -79,16 +115,18 @@ def __init__(self, args): self.variable_indices = self.precompute_variable_indices() self.selected_vars_units = [ - (var_name, var_unit) for var_name, var_unit in zip( - constants.param_names_short, constants.param_units - ) if var_name in constants.eval_plot_vars + (var_name, var_unit) + for var_name, var_unit in zip( + constants.PARAM_NAMES_SHORT, constants.PARAM_UNITS + ) + if var_name in constants.EVAL_PLOT_VARS ] print("variable_indices", self.variable_indices) print("selected_vars_units", self.selected_vars_units) @pl.utilities.rank_zero_only def log_image(self, name, img): - + """Log an image to wandb""" wandb.log({name: wandb.Image(img)}) @pl.utilities.rank_zero_only @@ -103,11 +141,22 @@ def init_metrics(self): self.metrics_initialized = True # Make sure this is done only once def configure_optimizers(self): - opt = torch.optim.AdamW(self.parameters(), lr=self.lr, betas=(0.9, 0.95)) - scheduler = torch.optim.lr_scheduler.StepLR(opt, step_size=30, gamma=0.1) + opt = torch.optim.AdamW( + self.parameters(), lr=self.lr, betas=(0.9, 0.95) + ) + scheduler = torch.optim.lr_scheduler.StepLR( + opt, step_size=30, gamma=0.1 + ) return [opt], [scheduler] + @property + def interior_mask_bool(self): + """ + Get the interior mask as a boolean (N,) mask. + """ + return self.interior_mask[:, 0].to(torch.bool) + @staticmethod def expand_to_batch(x, batch_size): """ @@ -115,18 +164,18 @@ def expand_to_batch(x, batch_size): """ return x.unsqueeze(0).expand(batch_size, -1, -1) - def setup(self, stage=None): - self.loss = self.loss.to(self.device) - self.interior_mask = self.interior_mask.to(self.device) - def precompute_variable_indices(self): + """ + Precompute indices for each variable in the input tensor + """ variable_indices = {} all_vars = [] index = 0 - # Create a list of tuples for all variables, using level 0 for 2D variables - for var_name in constants.param_names_short: - if constants.is_3d[var_name]: - for level in constants.vertical_levels: + # Create a list of tuples for all variables, using level 0 for 2D + # variables + for var_name in constants.PARAM_NAMES_SHORT: + if constants.IS_3D[var_name]: + for level in constants.VERTICAL_LEVELS: all_vars.append((var_name, level)) else: all_vars.append((var_name, 0)) # Use level 0 for 2D variables @@ -144,230 +193,291 @@ def precompute_variable_indices(self): return variable_indices def apply_constraints(self, prediction): - for param, (min_val, max_val) in constants.param_constraints.items(): + """ + Apply constraints to prediction to ensure values are within the + specified bounds + """ + for param, (min_val, max_val) in constants.PARAM_CONSTRAINTS.items(): indices = self.variable_indices[param] for index in indices: - # Apply clamping to ensure values are within the specified bounds + # Apply clamping to ensure values are within the specified + # bounds prediction[:, :, index] = torch.clamp( - prediction[:, :, index], min=min_val, max=max_val if max_val is not None else float('inf')) + prediction[:, :, index], + min=min_val, + max=max_val if max_val is not None else float("inf"), + ) return prediction - def predict_step(self, prev_state, prev_prev_state): + def predict_step( + self, + prev_state, + prev_prev_state, + batch_static_features=None, + forcing=None, + ): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 - prev_state: (B, N_grid, feature_dim), X_t - prev_prev_state: (B, N_grid, feature_dim), X_{t-1} - batch_static_features: (B, N_grid, batch_static_feature_dim) - forcing: (B, N_grid, forcing_dim) + prev_state: (B, num_grid_nodes, feature_dim), X_t + prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1} + batch_static_features: (B, num_grid_nodes, batch_static_feature_dim) + forcing: (B, num_grid_nodes, forcing_dim), optional """ - raise NotImplementedError("No prediction step implemented") - def unroll_prediction(self, init_states, true_states): + def unroll_prediction( + self, + init_states, + true_states, + batch_static_features=None, + forcing_features=None, + ): """ Roll out prediction taking multiple autoregressive steps with model - init_states: (B, 2, N_grid, d_f) - batch_static_features: (B, N_grid, d_static_f) - forcing_features: (B, pred_steps, N_grid, d_static_f) - true_states: (B, pred_steps, N_grid, d_f) + init_states: (B, 2, num_grid_nodes, d_f) + batch_static_features: (B, num_grid_nodes, d_static_f), optional + forcing_features: (B, pred_steps, num_grid_nodes, d_static_f), optional + true_states: (B, pred_steps, num_grid_nodes, d_f) """ - prev_prev_state = init_states[:, 0] prev_state = init_states[:, 1] prediction_list = [] - pred_steps = true_states.shape[1] + pred_std_list = [] + pred_steps = ( + forcing_features.shape[1] + if forcing_features is not None + else true_states.shape[1] + ) for i in range(pred_steps): + forcing = ( + forcing_features[:, i] if forcing_features is not None else None + ) border_state = true_states[:, i] - predicted_state = self.predict_step( - prev_state, - prev_prev_state) # (B, N_grid, d_f) + + pred_state, pred_std = self.predict_step( + prev_state, prev_prev_state, batch_static_features, forcing + ) + # state: (B, num_grid_nodes, d_f) + # pred_std: (B, num_grid_nodes, d_f) or None # Overwrite border with true state - new_state = self.border_mask * border_state +\ - self.interior_mask * predicted_state + new_state = ( + self.border_mask * border_state + + self.interior_mask * pred_state + ) + prediction_list.append(new_state) + if self.output_std: + pred_std_list.append(pred_std) - # Upate conditioning states + # Update conditioning states prev_prev_state = prev_state prev_state = new_state - return torch.stack(prediction_list, dim=1) # (B, pred_steps, N_grid, d_f) - - def weighted_loss(self, prediction, target, reduce_spatial_dim=True): - """ - Computed weighted loss function. - prediction/target: (B, pred_steps, N_grid, d_f) - returns (B, pred_steps) - """ - torch.autograd.set_detect_anomaly(True) - - entry_loss = self.loss(prediction, target) # (B, pred_steps, N_grid, d_f) - - # (B, pred_steps, N_grid), weighted sum over features - grid_node_loss = torch.mean(entry_loss * self.state_weight, dim=-1) - - if not reduce_spatial_dim: - return grid_node_loss # (B, pred_steps, N_grid) - - # Take (unweighted) mean over only non-border (interior) grid nodes - time_step_loss = torch.sum(grid_node_loss * self.interior_mask[:, 0], - dim=-1) / self.N_interior # (B, pred_steps) + prediction = torch.stack( + prediction_list, dim=1 + ) # (B, pred_steps, num_grid_nodes, d_f) + if self.output_std: + pred_std = torch.stack( + pred_std_list, dim=1 + ) # (B, pred_steps, num_grid_nodes, d_f) + else: + pred_std = self.per_var_std # (d_f,) - return time_step_loss # (B, pred_steps) + return prediction, pred_std def common_step(self, batch): """ Predict on single batch batch = time_series, batch_static_features, forcing_features - init_states: (B, 2, N_grid, d_features) - target_states: (B, pred_steps, N_grid, d_features) - batch_static_features: (B, N_grid, d_static_f), for example open water - forcing_features: (B, pred_steps, N_grid, d_forcing), where index 0 - corresponds to index 1 of init_states + init_states: (B, 2, num_grid_nodes, d_features) + target_states: (B, pred_steps, num_grid_nodes, d_features) + batch_static_features: (B, num_grid_nodes, d_static_f), optional + forcing_features: (B, pred_steps, num_grid_nodes, d_forcing), optional """ + init_states, target_states = batch[:2] + batch_static_features = batch[2] if len(batch) > 2 else None + forcing_features = batch[3] if len(batch) > 3 else None - init_states, target_states, = batch + prediction, pred_std = self.unroll_prediction( + init_states, + target_states, + batch_static_features, + forcing_features, + ) # (B, pred_steps, num_grid_nodes, d_f) + # prediction: (B, pred_steps, num_grid_nodes, d_f) + # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) - prediction = self.unroll_prediction( - init_states, target_states) # (B, pred_steps, N_grid, d_f) - - return prediction, target_states + return prediction, target_states, pred_std def training_step(self, batch): """ Train on single batch """ + prediction, target, pred_std = self.common_step(batch) - prediction, target = self.common_step(batch) # Compute loss - batch_loss = torch.mean(self.weighted_loss( - prediction, target)) # mean over unrolled times and batch + batch_loss = torch.mean( + self.loss( + prediction, target, pred_std, mask=self.interior_mask_bool + ) + ) # mean over unrolled times and batch + log_dict = {"train_loss": batch_loss} self.log_dict( - log_dict, - prog_bar=True, - on_step=True, - on_epoch=True, - sync_dist=True) + log_dict, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True + ) return batch_loss - def per_var_error(self, prediction, target, error="mae"): - """ - Computed MAE/MSE per variable and time step - prediction/target: (B, pred_steps, N_grid, d_f) - returns (B, pred_steps) - """ - - if error == "mse": - loss_func = torch.nn.functional.mse_loss - else: - loss_func = torch.nn.functional.l1_loss - entry_loss = loss_func(prediction, target, - reduction="none") # (B, pred_steps, N_grid, d_f) - - mean_error = torch.sum(entry_loss * self.interior_mask, - dim=2) / self.N_interior # (B, pred_steps, d_f) - return mean_error - def all_gather_cat(self, tensor_to_gather): """ - Gather tensors across all ranks, and concatenate across dim. 0 (instead of - stacking in new dim. 0) + Gather tensors across all ranks, and concatenate across dim. 0 + (instead of stacking in new dim. 0) tensor_to_gather: (d1, d2, ...), distributed over K ranks returns: (K*d1, d2, ...) """ - if torch.distributed.is_available() and torch.distributed.is_initialized(): - if torch.distributed.get_world_size() > 1: - tensor_to_gather = self.all_gather(tensor_to_gather).flatten(0, 1) - return tensor_to_gather + return self.all_gather(tensor_to_gather).flatten(0, 1) + # newer lightning versions requires batch_idx argument, even if unused + # pylint: disable-next=unused-argument def validation_step(self, batch, batch_idx): """ Run validation on single batch """ - prediction, target = self.common_step(batch) - - time_step_loss = torch.mean(self.weighted_loss(prediction, - target), dim=0) # (time_steps-1) + prediction, target, pred_std = self.common_step(batch) + + time_step_loss = torch.mean( + self.loss( + prediction, target, pred_std, mask=self.interior_mask_bool + ), + dim=0, + ) # (time_steps-1) mean_loss = torch.mean(time_step_loss) # Log loss per time step forward and mean - val_log_dict = {f"val_loss_unroll{step:02}": time_step_loss[step - 1] - for step in self.val_step_log_errors} + val_log_dict = { + f"val_loss_unroll{step}": time_step_loss[step - 1] + for step in constants.VAL_STEP_LOG_ERRORS + } val_log_dict["val_mean_loss"] = mean_loss - - errs = self.per_var_error( - prediction, target, error=self.loss_name) # (B, pred_steps, d_f) - self.val_errs.append(errs) - - self.log_dict(val_log_dict, on_step=False, on_epoch=True, sync_dist=True) + self.log_dict( + val_log_dict, on_step=False, on_epoch=True, sync_dist=True + ) + + # Store MSEs + entry_mses = metrics.mse( + prediction, + target, + pred_std, + mask=self.interior_mask_bool, + sum_vars=False, + ) # (B, pred_steps, d_f) + self.val_metrics["mse"].append(entry_mses) def on_validation_epoch_end(self): """ Compute val metrics at the end of val epoch """ - val_err_tensor = self.all_gather_cat(torch.cat( - self.val_errs, dim=0)) # (N_val, pred_steps, d_f) + # Create error maps for all test metrics + self.aggregate_and_plot_metrics(self.val_metrics, prefix="val") - if self.trainer.is_global_zero: - val_err_total = torch.mean(val_err_tensor, dim=0) # (pred_steps, d_f) - val_err_rescaled = val_err_total * self.data_std # (pred_steps, d_f) - - if not self.trainer.sanity_checking: - # Don't log this during sanity checking - val_err_fig = vis.plot_error_map( - val_err_rescaled, - self.data_mean, - title="Validation " + - self.loss_name.upper() + - " error", - step_length=self.step_length) - wandb.log({"val_err": wandb.Image(val_err_fig)}) - plt.close("all") - - self.val_errs.clear() # Free memory + # Clear lists with validation metrics values + for metric_list in self.val_metrics.values(): + metric_list.clear() + # pylint: disable-next=unused-argument def test_step(self, batch, batch_idx): """ Run test on single batch """ - - prediction, target = self.common_step(batch) - - time_step_loss = torch.mean(self.weighted_loss(prediction, - target), dim=0) # (time_steps-1) + prediction, target, pred_std = self.common_step(batch) + # prediction: (B, pred_steps, num_grid_nodes, d_f) + # pred_std: (B, pred_steps, num_grid_nodes, d_f) or (d_f,) + + time_step_loss = torch.mean( + self.loss( + prediction, target, pred_std, mask=self.interior_mask_bool + ), + dim=0, + ) # (time_steps-1,) mean_loss = torch.mean(time_step_loss) # Log loss per time step forward and mean - test_log_dict = {f"test_loss_unroll{step:02}": time_step_loss[step - 1] - for step in self.val_step_log_errors} + test_log_dict = { + f"test_loss_unroll{step}": time_step_loss[step - 1] + for step in constants.VAL_STEP_LOG_ERRORS + } test_log_dict["test_mean_loss"] = mean_loss - self.log_dict(test_log_dict, on_step=False, on_epoch=True, sync_dist=True) - - # For error maps - maes = self.per_var_error( - prediction, target, error="mae") # (B, pred_steps, d_f) - self.test_maes.append(maes) - mses = self.per_var_error( - prediction, target, error="mse") # (B, pred_steps, d_f) - self.test_mses.append(mses) + self.log_dict( + test_log_dict, on_step=False, on_epoch=True, sync_dist=True + ) + + # Compute all evaluation metrics for error maps + # Note: explicitly list metrics here, as test_metrics can contain + # additional ones, computed differently, but that should be aggregated + # on_test_epoch_end + for metric_name in ("mse", "mae"): + metric_func = metrics.get_metric(metric_name) + batch_metric_vals = metric_func( + prediction, + target, + pred_std, + mask=self.interior_mask_bool, + sum_vars=False, + ) # (B, pred_steps, d_f) + self.test_metrics[metric_name].append(batch_metric_vals) + + if self.output_std: + # Store output std. per variable, spatially averaged + mean_pred_std = torch.mean( + pred_std[..., self.interior_mask_bool, :], dim=-2 + ) # (B, pred_steps, d_f) + self.test_metrics["output_std"].append(mean_pred_std) # Save per-sample spatial loss for specific times - spatial_loss = self.weighted_loss( - prediction, target, reduce_spatial_dim=False) # (B, pred_steps, N_grid) - log_spatial_losses = spatial_loss[:, self.val_step_log_errors - 1] - self.spatial_loss_maps.append(log_spatial_losses) # (B, N_log, N_grid) + spatial_loss = self.loss( + prediction, target, pred_std, average_grid=False + ) # (B, pred_steps, num_grid_nodes) + log_spatial_losses = spatial_loss[:, constants.VAL_STEP_LOG_ERRORS - 1] + self.spatial_loss_maps.append(log_spatial_losses) + # (B, N_log, num_grid_nodes) + + # Plot example predictions (on rank 0 only) + if self.trainer.is_global_zero: + self.plot_examples(batch, batch_idx, prediction=prediction) - if self.global_rank == 0 and self.trainer.datamodule.test_dataset.batch_index == batch_idx: - index_within_batch = self.trainer.datamodule.test_dataset.index_within_batch + def plot_examples(self, batch, batch_idx, prediction=None): + """ + Plot the first n_examples forecasts from batch + + batch: batch with data to plot corresponding forecasts for + n_examples: number of forecasts to plot + prediction: (B, pred_steps, num_grid_nodes, d_f), existing prediction. + Generate if None. + """ + if prediction is None: + prediction, target = self.common_step(batch) + + target = batch[1] + + if ( + self.global_rank == 0 + and self.trainer.datamodule.test_dataset.batch_index == batch_idx + ): + index_within_batch = ( + self.trainer.datamodule.test_dataset.index_within_batch + ) if not torch.is_tensor(index_within_batch): index_within_batch = torch.tensor( - index_within_batch, dtype=torch.int64, device=prediction.device) + index_within_batch, + dtype=torch.int64, + device=prediction.device, + ) prediction = prediction[index_within_batch] target = target[index_within_batch] @@ -377,203 +487,300 @@ def test_step(self, batch, batch_idx): prediction_rescaled = self.apply_constraints(prediction_rescaled) target_rescaled = target * self.data_std + self.data_mean - # BUG: this creates artifacts at border cells, improve logic! - if constants.smooth_boundaries: - # (pred_steps, N_grid, d_f) - - height, width = constants.grid_shape - prediction_permuted = prediction_rescaled.permute( - 0, 2, 1).reshape( - prediction_rescaled.size(0), - prediction_rescaled.size(2), - height, width) - - # Define the smoothing kernel for grouped convolution - num_groups = prediction_permuted.shape[1] - kernel_size = 3 - kernel = torch.ones((num_groups, 1, kernel_size, - kernel_size)) / (kernel_size ** 2) - kernel = kernel.to(self.device) - - # Use the updated kernel in the conv2d operation - prediction_smoothed = nn.functional.conv2d( - prediction_permuted, kernel, padding=1, groups=num_groups) - - # (pred_steps, N_grid, channels) - # Combine the height and width dimensions back into a single N_grid - # dimension - prediction_smoothed = prediction_smoothed.reshape( - prediction_smoothed.size(0), prediction_smoothed.size(1), -1) - - # Permute the dimensions to get back to the original order (pred_steps, - # N_grid, d_f) - prediction_smoothed = prediction_smoothed.permute(0, 2, 1) + if constants.SMOOTH_BOUNDARIES: + # BUG: this creates artifacts at border cells, improve logic! + prediction_rescaled = self.smooth_prediction_borders( + prediction_rescaled + ) - # Apply the mask to the smoothed prediction - prediction_rescaled = self.border_mask * prediction_smoothed + self.interior_mask * prediction_rescaled - - # Each slice is (pred_steps, N_grid, d_f) - # Iterate over variables + # Each slice is (pred_steps, N_grid, d_f) Iterate over variables for var_name, var_unit in self.selected_vars_units: # Retrieve the indices for the current variable var_indices = self.variable_indices[var_name] for lvl_i, var_i in enumerate(var_indices): # Calculate var_vrange for each index - lvl = constants.vertical_levels[lvl_i] + lvl = constants.VERTICAL_LEVELS[lvl_i] var_vmin = min( prediction_rescaled[:, :, var_i].min(), - target_rescaled[:, :, var_i].min()) + target_rescaled[:, :, var_i].min(), + ) var_vmax = max( prediction_rescaled[:, :, var_i].max(), - target_rescaled[:, :, var_i].max()) + target_rescaled[:, :, var_i].max(), + ) var_vrange = (var_vmin, var_vmax) # Iterate over time steps for t_i, (pred_t, target_t) in enumerate( - zip(prediction_rescaled, target_rescaled), start=1): + zip(prediction_rescaled, target_rescaled), start=1 + ): eval_datetime_obj = datetime.strptime( - constants.eval_datetime, '%Y%m%d%H') - current_datetime_obj = eval_datetime_obj + timedelta(hours=t_i) - current_datetime_str = current_datetime_obj.strftime('%Y%m%d%H') - title = f"{var_name} ({var_unit}), t={current_datetime_str}" + constants.EVAL_DATETIME, "%Y%m%d%H" + ) + current_datetime_obj = eval_datetime_obj + timedelta( + hours=t_i + ) + current_datetime_str = current_datetime_obj.strftime( + "%Y%m%d%H" + ) + title = ( + f"{var_name} ({var_unit}), t={current_datetime_str}" + ) var_fig = vis.plot_prediction( - pred_t[:, var_i], target_t[:, var_i], - self.interior_mask[:, 0], + pred_t[:, var_i], + target_t[:, var_i], title=title, - vrange=var_vrange + vrange=var_vrange, ) wandb.log( - {f"{var_name}_lvl_{lvl:02}_t_{current_datetime_str}": wandb.Image(var_fig)} + { + ( + f"{var_name}_lvl_{lvl:02}_t_" + f"{current_datetime_str}" + ): wandb.Image(var_fig) + } ) plt.close("all") - if constants.store_example_data: + if constants.STORE_EXAMPLE_DATA: # Save pred and target as .pt files torch.save( prediction_rescaled.cpu(), - os.path.join( - wandb.run.dir, - 'example_pred.pt')) + os.path.join(wandb.run.dir, "example_pred.pt"), + ) torch.save( target_rescaled.cpu(), - os.path.join( - wandb.run.dir, - 'example_target.pt')) + os.path.join(wandb.run.dir, "example_target.pt"), + ) - def on_test_epoch_end(self): + def smooth_prediction_borders(self, prediction_rescaled): """ - Compute test metrics and make plots at the end of test epoch. - Will gather stored tensors and perform plotting and logging on rank 0. + Smooths the prediction at the borders to avoid artifacts. + + Args: + prediction_rescaled (torch.Tensor): The rescaled prediction tensor. + + Returns: + torch.Tensor: The prediction tensor after smoothing the borders. """ + height, width = constants.GRID_SHAPE + prediction_permuted = prediction_rescaled.permute(0, 2, 1).reshape( + prediction_rescaled.size(0), + prediction_rescaled.size(2), + height, + width, + ) + + # Define the smoothing kernel for grouped convolution + num_groups = prediction_permuted.shape[1] + kernel_size = 3 + kernel = torch.ones((num_groups, 1, kernel_size, kernel_size)) / ( + kernel_size**2 + ) + kernel = kernel.to(self.device) + + # Use the updated kernel in the conv2d operation + # pylint: disable-next=not-callable + prediction_smoothed = nn.functional.conv2d( + prediction_permuted, kernel, padding=1, groups=num_groups + ) + + # Combine the height and width dimensions back into a single N_grid + # dimension + prediction_smoothed = prediction_smoothed.reshape( + prediction_smoothed.size(0), prediction_smoothed.size(1), -1 + ) + + # Permute the dimensions to get back to the original order + prediction_smoothed = prediction_smoothed.permute(0, 2, 1) + + # Apply the mask to the smoothed prediction + prediction_rescaled = ( + self.border_mask * prediction_smoothed + + self.interior_mask * prediction_rescaled + ) + + return prediction_rescaled + + def create_metric_log_dict(self, metric_tensor, prefix, metric_name): + """ + Put together a dict with everything to log for one metric. + Also saves plots as pdf and csv if using test prefix. - # Create error maps for RMSE and MAE + metric_tensor: (pred_steps, d_f), metric values per time and variable + prefix: string, prefix to use for logging + metric_name: string, name of the metric - test_mae_tensor = self.all_gather_cat( - torch.cat(self.test_maes, dim=0)) # (N_test, pred_steps, d_f) - test_mse_tensor = self.all_gather_cat( - torch.cat(self.test_mses, dim=0)) # (N_test, pred_steps, d_f) + Return: + log_dict: dict with everything to log for given metric + """ + log_dict = {} + metric_fig = vis.plot_error_map( + metric_tensor, self.data_mean, step_length=self.step_length + ) + full_log_name = f"{prefix}_{metric_name}" + log_dict[full_log_name] = wandb.Image(metric_fig) + + if prefix == "test": + # Save pdf + metric_fig.savefig( + os.path.join(wandb.run.dir, f"{full_log_name}.pdf") + ) + # Save errors also as csv + np.savetxt( + os.path.join(wandb.run.dir, f"{full_log_name}.csv"), + metric_tensor.cpu().numpy(), + delimiter=",", + ) + + # Check if metrics are watched, log exact values for specific vars + if full_log_name in constants.METRICS_WATCH: + for var_i, timesteps in constants.VAR_LEADS_METRICS_WATCH.items(): + var = constants.PARAM_NAMES_SHORT[var_i] + log_dict.update( + { + f"{full_log_name}_{var}_step_{step}": metric_tensor[ + step - 1, var_i + ] # 1-indexed in constants + for step in timesteps + } + ) + + return log_dict + + def aggregate_and_plot_metrics(self, metrics_dict, prefix): + """ + Aggregate and create error map plots for all metrics in metrics_dict - if self.trainer.is_global_zero: - test_mae_rescaled = torch.mean(test_mae_tensor, - dim=0) * self.data_std # (pred_steps, d_f) - - test_rmse_rescaled = torch.sqrt( - torch.mean( - test_mse_tensor, - dim=0)) * self.data_std # (pred_steps, d_f) - - # Create plots only for these instances - mae_fig = vis.plot_error_map( - test_mae_rescaled[self.val_step_log_errors - 1], - self.data_mean, - step_length=self.step_length) - rmse_fig = vis.plot_error_map( - test_rmse_rescaled[self.val_step_log_errors - 1], - self.data_mean, - step_length=self.step_length) - - wandb.log({ # Log png:s - "test_mae": wandb.Image(mae_fig), - "test_rmse": wandb.Image(rmse_fig), - }) - - # Save pdf:s - mae_fig.savefig(os.path.join(wandb.run.dir, "test_mae.pdf")) - rmse_fig.savefig(os.path.join(wandb.run.dir, "test_rmse.pdf")) - # Save errors also as csv:s - - np.savetxt(os.path.join(wandb.run.dir, "test_mae.csv"), - test_mae_rescaled.cpu().numpy(), delimiter=",") - np.savetxt(os.path.join(wandb.run.dir, "test_rmse.csv"), - test_rmse_rescaled.cpu().numpy(), delimiter=",") - - self.test_maes.clear() # Free memory - self.test_mses.clear() + metrics_dict: dictionary with metric_names and list of tensors + with step-evals. + prefix: string, prefix to use for logging + """ + log_dict = {} + for metric_name, metric_val_list in metrics_dict.items(): + metric_tensor = torch.cat(metric_val_list, dim=0) + + if self.trainer.is_global_zero: + metric_tensor_averaged = torch.mean(metric_tensor, dim=0) + # (pred_steps, d_f) + + # Take square root after all averaging to change MSE to RMSE + if "mse" in metric_name: + metric_tensor_averaged = torch.sqrt(metric_tensor_averaged) + metric_name = metric_name.replace("mse", "rmse") + + # Note: we here assume rescaling for all metrics is linear + metric_rescaled = metric_tensor_averaged * self.data_std + # (pred_steps, d_f) + log_dict.update( + self.create_metric_log_dict( + metric_rescaled, prefix, metric_name + ) + ) + + if self.trainer.is_global_zero and not self.trainer.sanity_checking: + wandb.log(log_dict) # Log all + plt.close("all") # Close all figs + + def on_test_epoch_end(self): + """ + Compute test metrics and make plots at the end of test epoch. + Will gather stored tensors and perform plotting and logging on rank 0. + """ + # Create error maps for all test metrics + self.aggregate_and_plot_metrics(self.test_metrics, prefix="test") # Plot spatial loss maps spatial_loss_tensor = self.all_gather_cat( - torch.cat( - self.spatial_loss_maps, - dim=0)) # (N_test, N_log, N_grid) - + torch.cat(self.spatial_loss_maps, dim=0) + ) # (N_test, N_log, num_grid_nodes) if self.trainer.is_global_zero: mean_spatial_loss = torch.mean( - spatial_loss_tensor, dim=0) # (N_log, N_grid) - - # Create plots and PDFs only for these instances - loss_map_figs = [vis.plot_spatial_error( - mean_spatial_loss[i], self.interior_mask[:, 0], - title=f"Test loss, t={val_step}, ({self.step_length*val_step} h)") - for i, val_step in enumerate(self.val_step_log_errors)] - - # Log all to same wandb key, sequentially + spatial_loss_tensor, dim=0 + ) # (N_log, num_grid_nodes) + + loss_map_figs = [ + vis.plot_spatial_error( + loss_map, + title=f"Test loss, t={t_i} ({self.step_length * t_i} h)", + ) + for t_i, loss_map in zip( + constants.VAL_STEP_LOG_ERRORS, mean_spatial_loss + ) + ] + + # log all to same wandb key, sequentially for fig in loss_map_figs: wandb.log({"test_loss": wandb.Image(fig)}) - # Also make without title and save as PDF + # also make without title and save as pdf pdf_loss_map_figs = [ - vis.plot_spatial_error(loss_map, self.interior_mask[:, 0]) - for loss_map in mean_spatial_loss] + vis.plot_spatial_error(loss_map) + for loss_map in mean_spatial_loss + ] pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps") os.makedirs(pdf_loss_maps_dir, exist_ok=True) - for t_i, fig in zip(constants.val_step_log_errors, pdf_loss_map_figs): + for t_i, fig in zip( + constants.VAL_STEP_LOG_ERRORS, pdf_loss_map_figs + ): fig.savefig(os.path.join(pdf_loss_maps_dir, f"loss_t{t_i}.pdf")) # save mean spatial loss as .pt file also - torch.save(mean_spatial_loss.cpu(), os.path.join( - wandb.run.dir, 'mean_spatial_loss.pt')) + torch.save( + mean_spatial_loss.cpu(), + os.path.join(wandb.run.dir, "mean_spatial_loss.pt"), + ) dir_path = f"{wandb.run.dir}/media/images" for var_name, _ in self.selected_vars_units: var_indices = self.variable_indices[var_name] - for lvl_i, var_i in enumerate(var_indices): + for lvl_i, _ in enumerate(var_indices): # Calculate var_vrange for each index - lvl = constants.vertical_levels[lvl_i] + lvl = constants.VERTICAL_LEVELS[lvl_i] # Get all the images for the current variable and index images = sorted( - glob.glob(f"{dir_path}/{var_name}_lvl_{lvl:02}_t_*.png")) + glob.glob(f"{dir_path}/{var_name}_lvl_{lvl:02}_t_*.png") + ) # Generate the GIF - with imageio.get_writer(f'{dir_path}/{var_name}_lvl_{lvl:02}.gif', mode='I', fps=1) as writer: + with imageio.get_writer( + f"{dir_path}/{var_name}_lvl_{lvl:02}.gif", + mode="I", + fps=1, + ) as writer: for filename in images: image = imageio.imread(filename) writer.append_data(image) self.spatial_loss_maps.clear() - -def on_load_checkpoint(self, ckpt): - loaded_state_dict = ckpt["state_dict"] - - if "g2m_gnn.grid_mlp.0.weight" in loaded_state_dict: - replace_keys = list(filter(lambda key: key.startswith("g2m_gnn.grid_mlp"), - loaded_state_dict.keys())) - for old_key in replace_keys: - new_key = old_key.replace("g2m_gnn.grid_mlp", "encoding_grid_mlp") - loaded_state_dict[new_key] = loaded_state_dict[old_key] - del loaded_state_dict[old_key] - - if not self.resume_opt_sched: - # Create new optimizer and scheduler instances instead of setting them to None - optimizers, lr_schedulers = self.configure_optimizers() - ckpt['optimizer_states'] = [opt.state_dict() for opt in optimizers] - ckpt['lr_schedulers'] = [sched.state_dict() for sched in lr_schedulers] + def on_load_checkpoint(self, checkpoint): + """ + Perform any changes to state dict before loading checkpoint + """ + loaded_state_dict = checkpoint["state_dict"] + + # Fix for loading older models after IneractionNet refactoring, where + # the grid MLP was moved outside the encoder InteractionNet class + if "g2m_gnn.grid_mlp.0.weight" in loaded_state_dict: + replace_keys = list( + filter( + lambda key: key.startswith("g2m_gnn.grid_mlp"), + loaded_state_dict.keys(), + ) + ) + for old_key in replace_keys: + new_key = old_key.replace( + "g2m_gnn.grid_mlp", "encoding_grid_mlp" + ) + loaded_state_dict[new_key] = loaded_state_dict[old_key] + del loaded_state_dict[old_key] + if not self.restore_opt: + # Create new optimizer and scheduler instances instead of setting + # them to None + optimizers, lr_schedulers = self.configure_optimizers() + checkpoint["optimizer_states"] = [ + opt.state_dict() for opt in optimizers + ] + checkpoint["lr_schedulers"] = [ + sched.state_dict() for sched in lr_schedulers + ] diff --git a/neural_lam/models/base_graph_model.py b/neural_lam/models/base_graph_model.py index 14de66e1..a8b300b0 100644 --- a/neural_lam/models/base_graph_model.py +++ b/neural_lam/models/base_graph_model.py @@ -1,6 +1,8 @@ +# Third-party import torch -from neural_lam import utils +# First-party +from neural_lam import constants, utils from neural_lam.interaction_net import InteractionNet from neural_lam.models.ar_model import ARModel @@ -15,7 +17,8 @@ def __init__(self, args): super().__init__(args) # Load graph with static features - # NOTE: (IMPORTANT!) mesh nodes MUST have the first N_mesh indices, + # NOTE: (IMPORTANT!) mesh nodes MUST have the first + # num_mesh_nodes indices, self.hierarchical, graph_ldict = utils.load_graph(args.graph) for name, attr_value in graph_ldict.items(): # Make BufferLists module members and register tensors as buffers @@ -25,59 +28,51 @@ def __init__(self, args): setattr(self, name, attr_value) # Specify dimensions of data - self.N_grid, grid_static_dim = self.grid_static_features.shape # 63784 = 268x238 - self.N_mesh, N_mesh_ignore = self.get_num_mesh() - if self.global_rank == 0: - print(f"Loaded graph with {self.N_grid + self.N_mesh} nodes " + - f"({self.N_grid} grid, {self.N_mesh} mesh)") + self.num_mesh_nodes, _ = self.get_num_mesh() + print( + f"Loaded graph with {self.num_grid_nodes + self.num_mesh_nodes} " + f"nodes ({self.num_grid_nodes} grid, {self.num_mesh_nodes} mesh)" + ) # grid_dim from data + static + batch_static - grid_dim = 2 * self.grid_state_dim + grid_static_dim + self.grid_forcing_dim +\ - self.batch_static_feature_dim # 2*81 + 4 + 0 + 0 = 166 self.g2m_edges, g2m_dim = self.g2m_features.shape self.m2g_edges, m2g_dim = self.m2g_features.shape # Define sub-models # Feature embedders for grid self.mlp_blueprint_end = [args.hidden_dim] * (args.hidden_layers + 1) - self.grid_embedder = utils.make_mlp([grid_dim] + - self.mlp_blueprint_end) - self.g2m_embedder = utils.make_mlp([g2m_dim] + - self.mlp_blueprint_end) - self.m2g_embedder = utils.make_mlp([m2g_dim] + - self.mlp_blueprint_end) + self.grid_embedder = utils.make_mlp( + [self.grid_dim] + self.mlp_blueprint_end + ) + self.g2m_embedder = utils.make_mlp([g2m_dim] + self.mlp_blueprint_end) + self.m2g_embedder = utils.make_mlp([m2g_dim] + self.mlp_blueprint_end) # GNNs # encoder - # TODO: g2m and m2g edge indices to device? self.g2m_gnn = InteractionNet( self.g2m_edge_index, args.hidden_dim, hidden_layers=args.hidden_layers, - update_edges=False) - self.encoding_grid_mlp = utils.make_mlp([args.hidden_dim] - + self.mlp_blueprint_end) + update_edges=False, + ) + self.encoding_grid_mlp = utils.make_mlp( + [args.hidden_dim] + self.mlp_blueprint_end + ) # decoder self.m2g_gnn = InteractionNet( self.m2g_edge_index, args.hidden_dim, hidden_layers=args.hidden_layers, - update_edges=False) + update_edges=False, + ) # Output mapping (hidden_dim -> output_dim) self.output_map = utils.make_mlp( - [args.hidden_dim] * (args.hidden_layers + 1) + [self.grid_state_dim], - layer_norm=False) # No layer norm on this one - - def setup(self, stage=None): - super().setup(stage) - self.g2m_features = self.g2m_features.to(self.device) - self.m2g_features = self.m2g_features.to(self.device) - self.m2m_features = self.m2m_features.to(self.device) - self.step_diff_mean = self.step_diff_mean.to(self.device) - self.step_diff_std = self.step_diff_std.to(self.device) - self.grid_static_features = self.grid_static_features.to(self.device) + [args.hidden_dim] * (args.hidden_layers + 1) + + [self.grid_output_dim], + layer_norm=False, + ) # No layer norm on this one def get_num_mesh(self): """ @@ -88,8 +83,8 @@ def get_num_mesh(self): def embedd_mesh_nodes(self): """ - Embedd static mesh features - Returns tensor of shape (N_mesh, d_h) + Embed static mesh features + Returns tensor of shape (num_mesh_nodes, d_h) """ raise NotImplementedError("embedd_mesh_nodes not implemented") @@ -98,45 +93,64 @@ def process_step(self, mesh_rep): Process step of embedd-process-decode framework Processes the representation on the mesh, possible in multiple steps - mesh_rep: has shape (B, N_mesh, d_h) - Returns mesh_rep: (B, N_mesh, d_h) + mesh_rep: has shape (B, num_mesh_nodes, d_h) + Returns mesh_rep: (B, num_mesh_nodes, d_h) """ raise NotImplementedError("process_step not implemented") - def predict_step(self, prev_state, prev_prev_state): + def predict_step( + self, + prev_state, + prev_prev_state, + batch_static_features=None, + forcing=None, + ): """ Step state one step ahead using prediction model, X_{t-1}, X_t -> X_t+1 - prev_state: (B, N_grid, feature_dim), X_t - prev_prev_state: (B, N_grid, feature_dim), X_{t-1} - batch_static_features: (B, N_grid, batch_static_feature_dim) - forcing: (B, N_grid, forcing_dim) + prev_state: (B, num_grid_nodes, feature_dim), X_t + prev_prev_state: (B, num_grid_nodes, feature_dim), X_{t-1} + batch_static_features: (B, num_grid_nodes, batch_static_feature_dim) + forcing: (B, num_grid_nodes, forcing_dim), optional """ batch_size = prev_state.shape[0] - grid_features = torch.cat( - (prev_state, - prev_prev_state, - self.expand_to_batch( - self.grid_static_features, - batch_size)), - dim=-1) - - # Embedd all features - grid_emb = self.grid_embedder(grid_features) # (B, N_grid, d_h) + features_list = [ + prev_state, + prev_prev_state, + ] + + if ( + constants.BATCH_STATIC_FEATURE_DIM > 0 + and batch_static_features is not None + ): + features_list.append(batch_static_features) + if constants.GRID_FORCING_DIM > 0 and forcing is not None: + features_list.append(forcing) + features_list.append( + self.expand_to_batch(self.grid_static_features, batch_size) + ) + grid_features = torch.cat(features_list, dim=-1) + + # Embed all features + grid_emb = self.grid_embedder(grid_features) # (B, num_grid_nodes, d_h) g2m_emb = self.g2m_embedder(self.g2m_features) # (M_g2m, d_h) m2g_emb = self.m2g_embedder(self.m2g_features) # (M_m2g, d_h) mesh_emb = self.embedd_mesh_nodes() # Map from grid to mesh mesh_emb_expanded = self.expand_to_batch( - mesh_emb, batch_size) # (B, N_mesh, d_h) + mesh_emb, batch_size + ) # (B, num_mesh_nodes, d_h) g2m_emb_expanded = self.expand_to_batch(g2m_emb, batch_size) # This also splits representation into grid and mesh - mesh_rep = self.g2m_gnn(grid_emb, mesh_emb_expanded, - g2m_emb_expanded) # (B, N_mesh, d_h) + mesh_rep = self.g2m_gnn( + grid_emb, mesh_emb_expanded, g2m_emb_expanded + ) # (B, num_mesh_nodes, d_h) # Also MLP with residual for grid representation - grid_rep = grid_emb + self.encoding_grid_mlp(grid_emb) # (B, N_grid, d_h) + grid_rep = grid_emb + self.encoding_grid_mlp( + grid_emb + ) # (B, num_grid_nodes, d_h) # Run processor step mesh_rep = self.process_step(mesh_rep) @@ -144,15 +158,30 @@ def predict_step(self, prev_state, prev_prev_state): # Map back from mesh to grid m2g_emb_expanded = self.expand_to_batch(m2g_emb, batch_size) grid_rep = self.m2g_gnn( - mesh_rep, - grid_rep, - m2g_emb_expanded) # (B, N_grid, d_h) + mesh_rep, grid_rep, m2g_emb_expanded + ) # (B, num_grid_nodes, d_h) # Map to output dimension, only for grid - net_output = self.output_map(grid_rep) # (B, N_grid, d_f) + net_output = self.output_map( + grid_rep + ) # (B, num_grid_nodes, d_grid_out) + + if self.output_std: + pred_delta_mean, pred_std_raw = net_output.chunk( + 2, dim=-1 + ) # both (B, num_grid_nodes, d_f) + # Note: The predicted std. is not scaled in any way here + # linter for some reason does not think softplus is callable + # pylint: disable-next=not-callable + pred_std = torch.nn.functional.softplus(pred_std_raw) + else: + pred_delta_mean = net_output + pred_std = None # Rescale with one-step difference statistics - rescaled_net_output = net_output * self.step_diff_std + self.step_diff_mean + rescaled_delta_mean = ( + pred_delta_mean * self.step_diff_std + self.step_diff_mean + ) # Residual connection for full state - return prev_state + rescaled_net_output + return prev_state + rescaled_delta_mean, pred_std diff --git a/neural_lam/models/base_hi_graph_model.py b/neural_lam/models/base_hi_graph_model.py index b624380c..5772af6f 100644 --- a/neural_lam/models/base_hi_graph_model.py +++ b/neural_lam/models/base_hi_graph_model.py @@ -1,6 +1,7 @@ -import torch +# Third-party from torch import nn +# First-party from neural_lam import utils from neural_lam.interaction_net import InteractionNet from neural_lam.models.base_graph_model import BaseGraphModel @@ -16,24 +17,27 @@ def __init__(self, args): # Track number of nodes, edges on each level # Flatten lists for efficient embedding - self.N_levels = len(self.mesh_static_features) + self.num_levels = len(self.mesh_static_features) # Number of mesh nodes at each level - self.N_mesh_levels = [mesh_feat.shape[0] for mesh_feat in - self.mesh_static_features] # Needs as python list for later - N_mesh_levels_torch = torch.tensor(self.N_mesh_levels) + self.level_mesh_sizes = [ + mesh_feat.shape[0] for mesh_feat in self.mesh_static_features + ] # Needs as python list for later # Print some useful info - print("Loaded hierachical graph with structure:") - for ll, N_level in enumerate(self.N_mesh_levels): - same_level_edges = self.m2m_features[ll].shape[0] - print(f"level {ll} - {N_level} nodes, {same_level_edges} same-level edges") - - if ll < (self.N_levels - 1): - up_edges = self.mesh_up_features[ll].shape[0] - down_edges = self.mesh_down_features[ll].shape[0] - print(f" {ll}<->{ll+1} - {up_edges} up edges, {down_edges} down edges") - + print("Loaded hierarchical graph with structure:") + for level_index, level_mesh_size in enumerate(self.level_mesh_sizes): + same_level_edges = self.m2m_features[level_index].shape[0] + print( + f"level {level_index} - {level_mesh_size} nodes, " + f"{same_level_edges} same-level edges" + ) + + if level_index < (self.num_levels - 1): + up_edges = self.mesh_up_features[level_index].shape[0] + down_edges = self.mesh_down_features[level_index].shape[0] + print(f" {level_index}<->{level_index + 1}") + print(f" - {up_edges} up edges, {down_edges} down edges") # Embedders # Assume all levels have same static feature dimensionality mesh_dim = self.mesh_static_features[0].shape[1] @@ -43,44 +47,75 @@ def __init__(self, args): # Separate mesh node embedders for each level self.mesh_embedders = nn.ModuleList( - [utils.make_mlp([mesh_dim] + self.mlp_blueprint_end) - for _ in range(self.N_levels)]) + [ + utils.make_mlp([mesh_dim] + self.mlp_blueprint_end) + for _ in range(self.num_levels) + ] + ) self.mesh_same_embedders = nn.ModuleList( - [utils.make_mlp([mesh_same_dim] + self.mlp_blueprint_end) - for _ in range(self.N_levels)]) + [ + utils.make_mlp([mesh_same_dim] + self.mlp_blueprint_end) + for _ in range(self.num_levels) + ] + ) self.mesh_up_embedders = nn.ModuleList( - [utils.make_mlp([mesh_up_dim] + self.mlp_blueprint_end) - for _ in range(self.N_levels - 1)]) + [ + utils.make_mlp([mesh_up_dim] + self.mlp_blueprint_end) + for _ in range(self.num_levels - 1) + ] + ) self.mesh_down_embedders = nn.ModuleList( - [utils.make_mlp([mesh_down_dim] + self.mlp_blueprint_end) - for _ in range(self.N_levels - 1)]) + [ + utils.make_mlp([mesh_down_dim] + self.mlp_blueprint_end) + for _ in range(self.num_levels - 1) + ] + ) # Instantiate GNNs # Init GNNs - self.mesh_init_gnns = nn.ModuleList([InteractionNet( - edge_index, args.hidden_dim, hidden_layers=args.hidden_layers) - for edge_index in self.mesh_up_edge_index]) + self.mesh_init_gnns = nn.ModuleList( + [ + InteractionNet( + edge_index, + args.hidden_dim, + hidden_layers=args.hidden_layers, + ) + for edge_index in self.mesh_up_edge_index + ] + ) # Read out GNNs - self.mesh_read_gnns = nn.ModuleList([InteractionNet( - edge_index, args.hidden_dim, hidden_layers=args.hidden_layers, - update_edges=False) - for edge_index in self.mesh_down_edge_index]) + self.mesh_read_gnns = nn.ModuleList( + [ + InteractionNet( + edge_index, + args.hidden_dim, + hidden_layers=args.hidden_layers, + update_edges=False, + ) + for edge_index in self.mesh_down_edge_index + ] + ) def get_num_mesh(self): """ Compute number of mesh nodes from loaded features, and number of mesh nodes that should be ignored in encoding/decoding """ - N_mesh = sum(node_feat.shape[0] for node_feat in self.mesh_static_features) - N_mesh_ignore = N_mesh - self.mesh_static_features[0].shape[0] - return N_mesh, N_mesh_ignore + num_mesh_nodes = sum( + node_feat.shape[0] for node_feat in self.mesh_static_features + ) + num_mesh_nodes_ignore = ( + num_mesh_nodes - self.mesh_static_features[0].shape[0] + ) + return num_mesh_nodes, num_mesh_nodes_ignore def embedd_mesh_nodes(self): """ - Embedd static mesh features - This embedds only bottom level, rest is done at beginning of processing step - Returns tensor of shape (N_mesh[0], d_h) + Embed static mesh features + This embeds only bottom level, rest is done at beginning of + processing step + Returns tensor of shape (num_mesh_nodes[0], d_h) """ return self.mesh_embedders[0](self.mesh_static_features[0]) @@ -89,90 +124,106 @@ def process_step(self, mesh_rep): Process step of embedd-process-decode framework Processes the representation on the mesh, possible in multiple steps - mesh_rep: has shape (B, N_mesh, d_h) - Returns mesh_rep: (B, N_mesh, d_h) + mesh_rep: has shape (B, num_mesh_nodes, d_h) + Returns mesh_rep: (B, num_mesh_nodes, d_h) """ batch_size = mesh_rep.shape[0] - # EMBEDD REMAINING MESH NODES (levels >= 1) - + # EMBED REMAINING MESH NODES (levels >= 1) - # Create list of mesh node representations for each level, - # each of size (B, N_mesh[l], d_h) - mesh_rep_levels = [mesh_rep] + [self.expand_to_batch( - emb(node_static_features), batch_size) for - emb, node_static_features in - zip(list(self.mesh_embedders)[1:], list(self.mesh_static_features)[1:])] - - # - EMBEDD EDGES - - # Embedd edges, expand with batch dimension + # each of size (B, num_mesh_nodes[l], d_h) + mesh_rep_levels = [mesh_rep] + [ + self.expand_to_batch(emb(node_static_features), batch_size) + for emb, node_static_features in zip( + list(self.mesh_embedders)[1:], + list(self.mesh_static_features)[1:], + ) + ] + + # - EMBED EDGES - + # Embed edges, expand with batch dimension mesh_same_rep = [ - self.expand_to_batch( - emb(edge_feat), - batch_size) for emb, - edge_feat in zip( - self.mesh_same_embedders, - self.m2m_features)] + self.expand_to_batch(emb(edge_feat), batch_size) + for emb, edge_feat in zip( + self.mesh_same_embedders, self.m2m_features + ) + ] mesh_up_rep = [ - self.expand_to_batch( - emb(edge_feat), - batch_size) for emb, - edge_feat in zip( - self.mesh_up_embedders, - self.mesh_up_features)] + self.expand_to_batch(emb(edge_feat), batch_size) + for emb, edge_feat in zip( + self.mesh_up_embedders, self.mesh_up_features + ) + ] mesh_down_rep = [ - self.expand_to_batch( - emb(edge_feat), - batch_size) for emb, - edge_feat in zip( - self.mesh_down_embedders, - self.mesh_down_features)] + self.expand_to_batch(emb(edge_feat), batch_size) + for emb, edge_feat in zip( + self.mesh_down_embedders, self.mesh_down_features + ) + ] # - MESH INIT. - # Let level_l go from 1 to L for level_l, gnn in enumerate(self.mesh_init_gnns, start=1): # Extract representations - send_node_rep = mesh_rep_levels[level_l - 1] # (B, N_mesh[l-1], d_h) - rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h) + send_node_rep = mesh_rep_levels[ + level_l - 1 + ] # (B, num_mesh_nodes[l-1], d_h) + rec_node_rep = mesh_rep_levels[ + level_l + ] # (B, num_mesh_nodes[l], d_h) edge_rep = mesh_up_rep[level_l - 1] # Apply GNN - new_node_rep, new_edge_rep = gnn(send_node_rep, rec_node_rep, edge_rep) + new_node_rep, new_edge_rep = gnn( + send_node_rep, rec_node_rep, edge_rep + ) # Update node and edge vectors in lists - mesh_rep_levels[level_l] = new_node_rep # (B, N_mesh[l], d_h) + mesh_rep_levels[level_l] = ( + new_node_rep # (B, num_mesh_nodes[l], d_h) + ) mesh_up_rep[level_l - 1] = new_edge_rep # (B, M_up[l-1], d_h) # - PROCESSOR - mesh_rep_levels, _, _, mesh_down_rep = self.hi_processor_step( - mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep) + mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep + ) # - MESH READ OUT. - # Let level_l go from L-1 to 0 for level_l, gnn in zip( - range(self.N_levels - 2, -1, -1), - reversed(self.mesh_read_gnns)): + range(self.num_levels - 2, -1, -1), reversed(self.mesh_read_gnns) + ): # Extract representations - send_node_rep = mesh_rep_levels[level_l + 1] # (B, N_mesh[l+1], d_h) - rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h) + send_node_rep = mesh_rep_levels[ + level_l + 1 + ] # (B, num_mesh_nodes[l+1], d_h) + rec_node_rep = mesh_rep_levels[ + level_l + ] # (B, num_mesh_nodes[l], d_h) edge_rep = mesh_down_rep[level_l] # Apply GNN new_node_rep = gnn(send_node_rep, rec_node_rep, edge_rep) # Update node and edge vectors in lists - mesh_rep_levels[level_l] = new_node_rep # (B, N_mesh[l], d_h) + mesh_rep_levels[level_l] = ( + new_node_rep # (B, num_mesh_nodes[l], d_h) + ) # Return only bottom level representation - return mesh_rep_levels[0] # (B, N_mesh[0], d_h) + return mesh_rep_levels[0] # (B, num_mesh_nodes[0], d_h) - def hi_processor_step(self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, - mesh_down_rep): + def hi_processor_step( + self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep + ): """ Internal processor step of hierarchical graph models. Between mesh init and read out. Each input is list with representations, each with shape - mesh_rep_levels: (B, N_mesh[l], d_h) + mesh_rep_levels: (B, num_mesh_nodes[l], d_h) mesh_same_rep: (B, M_same[l], d_h) mesh_up_rep: (B, M_up[l -> l+1], d_h) mesh_down_rep: (B, M_down[l <- l+1], d_h) diff --git a/neural_lam/models/graph_lam.py b/neural_lam/models/graph_lam.py index 38bef032..f767fba0 100644 --- a/neural_lam/models/graph_lam.py +++ b/neural_lam/models/graph_lam.py @@ -1,6 +1,7 @@ -import torch +# Third-party import torch_geometric as pyg +# First-party from neural_lam import utils from neural_lam.interaction_net import InteractionNet from neural_lam.models.base_graph_model import BaseGraphModel @@ -8,47 +9,50 @@ class GraphLAM(BaseGraphModel): """ - Full graph-based LAM model that can be used with different (non-hierarchical )graphs. - Mainly based on GraphCast, but the model from Keisler (2022) almost identical. - Used for GC-LAM and L1-LAM in Oskarsson et al. (2023). + Full graph-based LAM model that can be used with different + (non-hierarchical )graphs. Mainly based on GraphCast, but the model from + Keisler (2022) is almost identical. Used for GC-LAM and L1-LAM in + Oskarsson et al. (2023). """ def __init__(self, args): super().__init__(args) - assert not self.hierarchical, "GraphLAM does not use a hierarchical mesh graph" + assert ( + not self.hierarchical + ), "GraphLAM does not use a hierarchical mesh graph" # grid_dim from data + static + batch_static mesh_dim = self.mesh_static_features.shape[1] m2m_edges, m2m_dim = self.m2m_features.shape - if torch.distributed.get_rank == 0: - print(f"Edges in subgraphs: m2m={m2m_edges}, g2m={self.g2m_edges}, " - f"m2g={self.m2g_edges}") + print( + f"Edges in subgraphs: m2m={m2m_edges}, g2m={self.g2m_edges}, " + f"m2g={self.m2g_edges}" + ) # Define sub-models # Feature embedders for mesh - self.mesh_embedder = utils.make_mlp([mesh_dim] + - self.mlp_blueprint_end) - self.m2m_embedder = utils.make_mlp([m2m_dim] + - self.mlp_blueprint_end) - self.args = args + self.mesh_embedder = utils.make_mlp([mesh_dim] + self.mlp_blueprint_end) + self.m2m_embedder = utils.make_mlp([m2m_dim] + self.mlp_blueprint_end) - def setup(self, stage=None): - super().setup(stage) - # TODO: m2m, to device? # GNNs # processor processor_nets = [ InteractionNet( - self.m2m_edge_index, self.args.hidden_dim, - hidden_layers=self.args.hidden_layers, aggr=self.args.mesh_aggr) - for _ in range(self.args.processor_layers)] - self.processor = pyg.nn.Sequential("mesh_rep, edge_rep", [ - (net, "mesh_rep, mesh_rep, edge_rep -> mesh_rep, edge_rep") - for net in processor_nets]) - # Move the entire processor to the device - for net in self.processor: - net.to(self.device) + self.m2m_edge_index, + args.hidden_dim, + hidden_layers=args.hidden_layers, + aggr=args.mesh_aggr, + ) + for _ in range(args.processor_layers) + ] + self.processor = pyg.nn.Sequential( + "mesh_rep, edge_rep", + [ + (net, "mesh_rep, mesh_rep, edge_rep -> mesh_rep, edge_rep") + for net in processor_nets + ], + ) def get_num_mesh(self): """ @@ -59,11 +63,10 @@ def get_num_mesh(self): def embedd_mesh_nodes(self): """ - Embedd static mesh features + Embed static mesh features Returns tensor of shape (N_mesh, d_h) """ - return self.mesh_embedder( - self.mesh_static_features.to(self.device)) # (N_mesh, d_h) + return self.mesh_embedder(self.mesh_static_features) # (N_mesh, d_h) def process_step(self, mesh_rep): """ @@ -73,10 +76,14 @@ def process_step(self, mesh_rep): mesh_rep: has shape (B, N_mesh, d_h) Returns mesh_rep: (B, N_mesh, d_h) """ - # Embedd m2m here first + # Embed m2m here first batch_size = mesh_rep.shape[0] m2m_emb = self.m2m_embedder(self.m2m_features) # (M_mesh, d_h) - m2m_emb_expanded = self.expand_to_batch(m2m_emb, batch_size) # (B, M_mesh, d_h) + m2m_emb_expanded = self.expand_to_batch( + m2m_emb, batch_size + ) # (B, M_mesh, d_h) - mesh_rep, _ = self.processor(mesh_rep, m2m_emb_expanded) # (B, N_mesh, d_h) + mesh_rep, _ = self.processor( + mesh_rep, m2m_emb_expanded + ) # (B, N_mesh, d_h) return mesh_rep diff --git a/neural_lam/models/hi_lam.py b/neural_lam/models/hi_lam.py index 1d9a840b..4d7eb94c 100644 --- a/neural_lam/models/hi_lam.py +++ b/neural_lam/models/hi_lam.py @@ -1,13 +1,15 @@ +# Third-party from torch import nn +# First-party from neural_lam.interaction_net import InteractionNet from neural_lam.models.base_hi_graph_model import BaseHiGraphModel class HiLAM(BaseHiGraphModel): """ - Hierarchical graph model with message passing that goes sequentially down and up - the hierarchy during processing. + Hierarchical graph model with message passing that goes sequentially down + and up the hierarchy during processing. The Hi-LAM model from Oskarsson et al. (2023) """ @@ -15,107 +17,152 @@ def __init__(self, args): super().__init__(args) # Make down GNNs, both for down edges and same level - self.mesh_down_gnns = nn.ModuleList([self.make_down_gnns(args) for _ in range( - args.processor_layers)]) # Nested lists (proc_steps, N_levels-1) - self.mesh_down_same_gnns = nn.ModuleList([self.make_same_gnns(args) for _ in range( - args.processor_layers)]) # Nested lists (proc_steps, N_levels) + self.mesh_down_gnns = nn.ModuleList( + [self.make_down_gnns(args) for _ in range(args.processor_layers)] + ) # Nested lists (proc_steps, num_levels-1) + self.mesh_down_same_gnns = nn.ModuleList( + [self.make_same_gnns(args) for _ in range(args.processor_layers)] + ) # Nested lists (proc_steps, num_levels) # Make up GNNs, both for up edges and same level - self.mesh_up_gnns = nn.ModuleList([self.make_up_gnns(args) for _ in range( - args.processor_layers)]) # Nested lists (proc_steps, N_levels-1) - self.mesh_up_same_gnns = nn.ModuleList([self.make_same_gnns(args) for _ in range( - args.processor_layers)]) # Nested lists (proc_steps, N_levels) + self.mesh_up_gnns = nn.ModuleList( + [self.make_up_gnns(args) for _ in range(args.processor_layers)] + ) # Nested lists (proc_steps, num_levels-1) + self.mesh_up_same_gnns = nn.ModuleList( + [self.make_same_gnns(args) for _ in range(args.processor_layers)] + ) # Nested lists (proc_steps, num_levels) def make_same_gnns(self, args): """ Make intra-level GNNs. """ - return nn.ModuleList([InteractionNet( - edge_index, args.hidden_dim, hidden_layers=args.hidden_layers) - for edge_index in self.m2m_edge_index]) + return nn.ModuleList( + [ + InteractionNet( + edge_index, + args.hidden_dim, + hidden_layers=args.hidden_layers, + ) + for edge_index in self.m2m_edge_index + ] + ) def make_up_gnns(self, args): """ Make GNNs for processing steps up through the hierarchy. """ - return nn.ModuleList([InteractionNet( - edge_index, args.hidden_dim, hidden_layers=args.hidden_layers) - for edge_index in self.mesh_up_edge_index]) + return nn.ModuleList( + [ + InteractionNet( + edge_index, + args.hidden_dim, + hidden_layers=args.hidden_layers, + ) + for edge_index in self.mesh_up_edge_index + ] + ) def make_down_gnns(self, args): """ Make GNNs for processing steps down through the hierarchy. """ - return nn.ModuleList([InteractionNet( - edge_index, args.hidden_dim, hidden_layers=args.hidden_layers) - for edge_index in self.mesh_down_edge_index]) - - def mesh_down_step(self, mesh_rep_levels, mesh_same_rep, mesh_down_rep, down_gnns, - same_gnns): + return nn.ModuleList( + [ + InteractionNet( + edge_index, + args.hidden_dim, + hidden_layers=args.hidden_layers, + ) + for edge_index in self.mesh_down_edge_index + ] + ) + + def mesh_down_step( + self, + mesh_rep_levels, + mesh_same_rep, + mesh_down_rep, + down_gnns, + same_gnns, + ): """ - Run down-part of vertical processing, sequentially alternating between processing - using down edges and same-level edges. + Run down-part of vertical processing, sequentially alternating between + processing using down edges and same-level edges. """ - # Run same level processing on level L mesh_rep_levels[-1], mesh_same_rep[-1] = same_gnns[-1]( - mesh_rep_levels[-1], mesh_rep_levels[-1], mesh_same_rep[-1]) + mesh_rep_levels[-1], mesh_rep_levels[-1], mesh_same_rep[-1] + ) # Let level_l go from L-1 to 0 for level_l, down_gnn, same_gnn in zip( - range(self.N_levels - 2, -1, -1), - reversed(down_gnns), reversed(same_gnns[:-1])): + range(self.num_levels - 2, -1, -1), + reversed(down_gnns), + reversed(same_gnns[:-1]), + ): # Extract representations - send_node_rep = mesh_rep_levels[level_l + 1] # (B, N_mesh[l+1], d_h) + send_node_rep = mesh_rep_levels[ + level_l + 1 + ] # (B, N_mesh[l+1], d_h) rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h) down_edge_rep = mesh_down_rep[level_l] same_edge_rep = mesh_same_rep[level_l] # Apply down GNN - new_node_rep, mesh_down_rep[level_l] = down_gnn(send_node_rep, rec_node_rep, - down_edge_rep) + new_node_rep, mesh_down_rep[level_l] = down_gnn( + send_node_rep, rec_node_rep, down_edge_rep + ) # Run same level processing on level l mesh_rep_levels[level_l], mesh_same_rep[level_l] = same_gnn( - new_node_rep, new_node_rep, same_edge_rep) + new_node_rep, new_node_rep, same_edge_rep + ) # (B, N_mesh[l], d_h) and (B, M_same[l], d_h) return mesh_rep_levels, mesh_same_rep, mesh_down_rep - def mesh_up_step(self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, up_gnns, - same_gnns): + def mesh_up_step( + self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, up_gnns, same_gnns + ): """ - Run up-part of vertical processing, sequentially alternating between processing - using up edges and same-level edges. + Run up-part of vertical processing, sequentially alternating between + processing using up edges and same-level edges. """ # Run same level processing on level 0 mesh_rep_levels[0], mesh_same_rep[0] = same_gnns[0]( - mesh_rep_levels[0], mesh_rep_levels[0], mesh_same_rep[0]) + mesh_rep_levels[0], mesh_rep_levels[0], mesh_same_rep[0] + ) # Let level_l go from 1 to L - for level_l, (up_gnn, same_gnn) in enumerate(zip(up_gnns, same_gnns[1:]), - start=1): + for level_l, (up_gnn, same_gnn) in enumerate( + zip(up_gnns, same_gnns[1:]), start=1 + ): # Extract representations - send_node_rep = mesh_rep_levels[level_l - 1] # (B, N_mesh[l-1], d_h) + send_node_rep = mesh_rep_levels[ + level_l - 1 + ] # (B, N_mesh[l-1], d_h) rec_node_rep = mesh_rep_levels[level_l] # (B, N_mesh[l], d_h) up_edge_rep = mesh_up_rep[level_l - 1] same_edge_rep = mesh_same_rep[level_l] # Apply up GNN - new_node_rep, mesh_up_rep[level_l - 1] = up_gnn(send_node_rep, rec_node_rep, - up_edge_rep) + new_node_rep, mesh_up_rep[level_l - 1] = up_gnn( + send_node_rep, rec_node_rep, up_edge_rep + ) # (B, N_mesh[l], d_h) and (B, M_up[l-1], d_h) # Run same level processing on level l mesh_rep_levels[level_l], mesh_same_rep[level_l] = same_gnn( - new_node_rep, new_node_rep, same_edge_rep) + new_node_rep, new_node_rep, same_edge_rep + ) # (B, N_mesh[l], d_h) and (B, M_same[l], d_h) return mesh_rep_levels, mesh_same_rep, mesh_up_rep - def hi_processor_step(self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, - mesh_down_rep): + def hi_processor_step( + self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep + ): """ Internal processor step of hierarchical graph models. Between mesh init and read out. @@ -130,16 +177,28 @@ def hi_processor_step(self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, Returns same lists """ for down_gnns, down_same_gnns, up_gnns, up_same_gnns in zip( - self.mesh_down_gnns, self.mesh_down_same_gnns, self.mesh_up_gnns, self.mesh_up_same_gnns): + self.mesh_down_gnns, + self.mesh_down_same_gnns, + self.mesh_up_gnns, + self.mesh_up_same_gnns, + ): # Down mesh_rep_levels, mesh_same_rep, mesh_down_rep = self.mesh_down_step( - mesh_rep_levels, mesh_same_rep, mesh_down_rep, down_gnns, - down_same_gnns) + mesh_rep_levels, + mesh_same_rep, + mesh_down_rep, + down_gnns, + down_same_gnns, + ) # Up mesh_rep_levels, mesh_same_rep, mesh_up_rep = self.mesh_up_step( - mesh_rep_levels, mesh_same_rep, mesh_up_rep, up_gnns, - up_same_gnns) + mesh_rep_levels, + mesh_same_rep, + mesh_up_rep, + up_gnns, + up_same_gnns, + ) # Note: We return all, even though only down edges really are used later return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep diff --git a/neural_lam/models/hi_lam_parallel.py b/neural_lam/models/hi_lam_parallel.py index 10308309..740824e1 100644 --- a/neural_lam/models/hi_lam_parallel.py +++ b/neural_lam/models/hi_lam_parallel.py @@ -1,42 +1,58 @@ +# Third-party import torch import torch_geometric as pyg +# First-party from neural_lam.interaction_net import InteractionNet from neural_lam.models.base_hi_graph_model import BaseHiGraphModel class HiLAMParallel(BaseHiGraphModel): """ - Version of HiLAM where all message passing in the hierarchical mesh (up, down, - inter-level) is ran in paralell. + Version of HiLAM where all message passing in the hierarchical mesh (up, + down, inter-level) is ran in parallel. - This is a somewhat simpler alternative to the sequential message passing of Hi-LAM. + This is a somewhat simpler alternative to the sequential message passing + of Hi-LAM. """ def __init__(self, args): super().__init__(args) # Processor GNNs - # Create the complete total edge_index combining all edges for processing - total_edge_index_list = list(self.m2m_edge_index) +\ - list(self.mesh_up_edge_index) + list(self.mesh_down_edge_index) + # Create the complete edge_index combining all edges for processing + total_edge_index_list = ( + list(self.m2m_edge_index) + + list(self.mesh_up_edge_index) + + list(self.mesh_down_edge_index) + ) total_edge_index = torch.cat(total_edge_index_list, dim=1) self.edge_split_sections = [ei.shape[1] for ei in total_edge_index_list] if args.processor_layers == 0: - self.processor = (lambda x, edge_attr: (x, edge_attr)) + self.processor = lambda x, edge_attr: (x, edge_attr) else: - processor_nets = [InteractionNet(total_edge_index, args.hidden_dim, - hidden_layers=args.hidden_layers, - edge_chunk_sizes=self.edge_split_sections, - aggr_chunk_sizes=self.N_mesh_levels) - for _ in range(args.processor_layers)] - self.processor = pyg.nn.Sequential("mesh_rep, edge_rep", [ - (net, "mesh_rep, mesh_rep, edge_rep -> mesh_rep, edge_rep") - for net in processor_nets]) - - def hi_processor_step(self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, - mesh_down_rep): + processor_nets = [ + InteractionNet( + total_edge_index, + args.hidden_dim, + hidden_layers=args.hidden_layers, + edge_chunk_sizes=self.edge_split_sections, + aggr_chunk_sizes=self.level_mesh_sizes, + ) + for _ in range(args.processor_layers) + ] + self.processor = pyg.nn.Sequential( + "mesh_rep, edge_rep", + [ + (net, "mesh_rep, mesh_rep, edge_rep -> mesh_rep, edge_rep") + for net in processor_nets + ], + ) + + def hi_processor_step( + self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep + ): """ Internal processor step of hierarchical graph models. Between mesh init and read out. @@ -53,22 +69,28 @@ def hi_processor_step(self, mesh_rep_levels, mesh_same_rep, mesh_up_rep, # First join all node and edge representations to single tensors mesh_rep = torch.cat(mesh_rep_levels, dim=1) # (B, N_mesh, d_h) - mesh_edge_rep = torch.cat(mesh_same_rep + mesh_up_rep + mesh_down_rep, - axis=1) # (B, M_mesh, d_h) + mesh_edge_rep = torch.cat( + mesh_same_rep + mesh_up_rep + mesh_down_rep, axis=1 + ) # (B, M_mesh, d_h) # Here, update mesh_*_rep and mesh_rep mesh_rep, mesh_edge_rep = self.processor(mesh_rep, mesh_edge_rep) # Split up again for read-out step - mesh_rep_levels = list(torch.split(mesh_rep, self.N_mesh_levels, dim=1)) - mesh_edge_rep_sections = torch.split(mesh_edge_rep, self.edge_split_sections, - dim=1) - - mesh_same_rep = mesh_edge_rep_sections[:self.N_levels] + mesh_rep_levels = list( + torch.split(mesh_rep, self.level_mesh_sizes, dim=1) + ) + mesh_edge_rep_sections = torch.split( + mesh_edge_rep, self.edge_split_sections, dim=1 + ) + + mesh_same_rep = mesh_edge_rep_sections[: self.num_levels] mesh_up_rep = mesh_edge_rep_sections[ - self.N_levels:self.N_levels + (self.N_levels - 1)] + self.num_levels : self.num_levels + (self.num_levels - 1) + ] mesh_down_rep = mesh_edge_rep_sections[ - self.N_levels + (self.N_levels - 1):] # Last are down edges + self.num_levels + (self.num_levels - 1) : + ] # Last are down edges # Note: We return all, even though only down edges really are used later return mesh_rep_levels, mesh_same_rep, mesh_up_rep, mesh_down_rep diff --git a/neural_lam/rotate_grid.py b/neural_lam/rotate_grid.py index 95081fab..23294486 100644 --- a/neural_lam/rotate_grid.py +++ b/neural_lam/rotate_grid.py @@ -1,7 +1,9 @@ """unrotate rotated pole coordinates to geographical lat/lon""" +# Third-party import numpy as np +# First-party from neural_lam import constants @@ -37,15 +39,17 @@ def unrot_lon(rotlon, rotlat, pollon, pollat): c2 = np.cos(np.radians(pollon)) # subresults - tmp1 = s2 * (-s1 * np.cos(rlo) * np.cos(rla) + c1 * - np.sin(rla)) - c2 * np.sin(rlo) * np.cos(rla) - tmp2 = c2 * (-s1 * np.cos(rlo) * np.cos(rla) + c1 * - np.sin(rla)) + s2 * np.sin(rlo) * np.cos(rla) + tmp1 = s2 * ( + -s1 * np.cos(rlo) * np.cos(rla) + c1 * np.sin(rla) + ) - c2 * np.sin(rlo) * np.cos(rla) + tmp2 = c2 * ( + -s1 * np.cos(rlo) * np.cos(rla) + c1 * np.sin(rla) + ) + s2 * np.sin(rlo) * np.cos(rla) return np.degrees(np.arctan(tmp1 / tmp2)) -def unrot_lat(rotlat, rotlon, pollon, pollat): +def unrot_lat(rotlat, rotlon, pollat): """Transform rotated latitude to latitude. Parameters @@ -81,9 +85,10 @@ def unrot_lat(rotlat, rotlon, pollon, pollat): def unrotate_latlon(data): + """Unrotate lat/lon coordinates from rotated pole grid.""" xx, yy = np.meshgrid(data.x_1.values, data.y_1.values) # unrotate lon/lat - lon = unrot_lon(xx, yy, constants.pollon, constants.pollat) - lat = unrot_lat(yy, xx, constants.pollon, constants.pollat) + lon = unrot_lon(xx, yy, constants.POLLON, constants.POLLAT) + lat = unrot_lat(yy, xx, constants.POLLAT) return lon.T, lat.T diff --git a/neural_lam/utils.py b/neural_lam/utils.py index a288f77a..e21fe083 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -1,43 +1,66 @@ +# Standard library import os +# Third-party import numpy as np import torch -import torch.nn as nn from pytorch_lightning.utilities import rank_zero_only +from torch import nn from tueplots import bundles, figsizes +# First-party from neural_lam import constants def load_dataset_stats(dataset_name, device="cpu"): + """ + Load arrays with stored dataset statistics from pre-processing + """ static_dir_path = os.path.join("data", dataset_name, "static") def loads_file(fn): - return torch.load(os.path.join(static_dir_path, fn), map_location=device) + return torch.load( + os.path.join(static_dir_path, fn), map_location=device + ) data_mean = loads_file("parameter_mean.pt") # (d_features,) data_std = loads_file("parameter_std.pt") # (d_features,) - return { - "data_mean": data_mean, - "data_std": data_std, - # "flux_mean": flux_mean, - # "flux_std": flux_std, - } + if constants.GRID_FORCING_DIM > 0: + flux_stats = loads_file("flux_stats.pt") # (2,) + flux_mean, flux_std = flux_stats + + return { + "data_mean": data_mean, + "data_std": data_std, + "flux_mean": flux_mean, + "flux_std": flux_std, + } + return {"data_mean": data_mean, "data_std": data_std} def load_static_data(dataset_name, device="cpu"): + """ + Load static files related to dataset + """ static_dir_path = os.path.join("data", dataset_name, "static") def loads_file(fn): - return torch.load(os.path.join(static_dir_path, fn), map_location=device) + return torch.load( + os.path.join(static_dir_path, fn), map_location=device + ) # Load border mask, 1. if node is part of border, else 0. border_mask_np = np.load(os.path.join(static_dir_path, "border_mask.npy")) - border_mask = torch.tensor(border_mask_np, dtype=torch.float32, - device=device).flatten(0, 1).unsqueeze(1) # (N_grid, 1) + border_mask = ( + torch.tensor(border_mask_np, dtype=torch.float32, device=device) + .flatten(0, 1) + .unsqueeze(1) + ) # (N_grid, 1) - grid_static_features = loads_file("grid_features.pt") # (N_grid, d_grid_static) + grid_static_features = loads_file( + "grid_features.pt" + ) # (N_grid, d_grid_static) # Load step diff stats step_diff_mean = loads_file("diff_mean.pt") # (d_f,) @@ -49,12 +72,10 @@ def loads_file(fn): # Load loss weighting vectors param_weights = torch.tensor( - np.load( - os.path.join( - static_dir_path, - "parameter_weights.npy")), + np.load(os.path.join(static_dir_path, "parameter_weights.npy")), dtype=torch.float32, - device=device) # (d_f,) + device=device, + ) # (d_f,) return { "border_mask": border_mask, @@ -69,8 +90,8 @@ def loads_file(fn): class BufferList(nn.Module): """ - A list of torch buffer tensors that sit together as a Module with no parameters and only - buffers. + A list of torch buffer tensors that sit together as a Module with no + parameters and only buffers. This should be replaced by a native torch BufferList once implemented. See: https://github.com/pytorch/pytorch/issues/37386 @@ -93,6 +114,9 @@ def __iter__(self): def load_graph(graph_name, device="cpu"): + """ + Load all tensors representing the graph + """ # Define helper lambda function graph_dir_path = os.path.join("graphs", graph_name) @@ -100,8 +124,9 @@ def loads_file(fn): return torch.load(os.path.join(graph_dir_path, fn), map_location=device) # Load edges (edge_index) - m2m_edge_index = BufferList(loads_file("m2m_edge_index.pt"), - persistent=False) # List of (2, M_m2m[l]) + m2m_edge_index = BufferList( + loads_file("m2m_edge_index.pt"), persistent=False + ) # List of (2, M_m2m[l]) g2m_edge_index = loads_file("g2m_edge_index.pt") # (2, M_g2m) m2g_edge_index = loads_file("m2g_edge_index.pt") # (2, M_m2g) @@ -114,50 +139,76 @@ def loads_file(fn): m2g_features = loads_file("m2g_features.pt") # (M_m2g, d_edge_f) # Normalize by dividing with longest edge (found in m2m) - longest_edge = max([torch.max(level_features[:, 0]) - for level_features in m2m_features]) # Col. 0 is length - m2m_features = BufferList([level_features / longest_edge - for level_features in m2m_features], persistent=False) + longest_edge = max( + torch.max(level_features[:, 0]) for level_features in m2m_features + ) # Col. 0 is length + m2m_features = BufferList( + [level_features / longest_edge for level_features in m2m_features], + persistent=False, + ) g2m_features = g2m_features / longest_edge m2g_features = m2g_features / longest_edge # Load static node features - mesh_static_features = loads_file("mesh_features.pt" - ) # List of (N_mesh[l], d_mesh_static) + mesh_static_features = loads_file( + "mesh_features.pt" + ) # List of (N_mesh[l], d_mesh_static) # Some checks for consistency - assert len(m2m_features) == n_levels, "Inconsistent number of levels in mesh" - assert len(mesh_static_features) == n_levels, "Inconsistent number of levels in mesh" + assert ( + len(m2m_features) == n_levels + ), "Inconsistent number of levels in mesh" + assert ( + len(mesh_static_features) == n_levels + ), "Inconsistent number of levels in mesh" if hierarchical: # Load up and down edges and features - mesh_up_edge_index = BufferList(loads_file("mesh_up_edge_index.pt"), - persistent=False) # List of (2, M_up[l]) - mesh_down_edge_index = BufferList(loads_file("mesh_down_edge_index.pt"), - persistent=False) # List of (2, M_down[l]) - - mesh_up_features = loads_file("mesh_up_features.pt" - ) # List of (M_up[l], d_edge_f) - mesh_down_features = loads_file("mesh_down_features.pt" - ) # List of (M_down[l], d_edge_f) + mesh_up_edge_index = BufferList( + loads_file("mesh_up_edge_index.pt"), persistent=False + ) # List of (2, M_up[l]) + mesh_down_edge_index = BufferList( + loads_file("mesh_down_edge_index.pt"), persistent=False + ) # List of (2, M_down[l]) + + mesh_up_features = loads_file( + "mesh_up_features.pt" + ) # List of (M_up[l], d_edge_f) + mesh_down_features = loads_file( + "mesh_down_features.pt" + ) # List of (M_down[l], d_edge_f) # Rescale mesh_up_features = BufferList( - [edge_features / longest_edge for edge_features in mesh_up_features], - persistent=False) + [ + edge_features / longest_edge + for edge_features in mesh_up_features + ], + persistent=False, + ) mesh_down_features = BufferList( - [edge_features / longest_edge for edge_features in mesh_down_features], - persistent=False) - - mesh_static_features = BufferList(mesh_static_features, persistent=False) + [ + edge_features / longest_edge + for edge_features in mesh_down_features + ], + persistent=False, + ) + + mesh_static_features = BufferList( + mesh_static_features, persistent=False + ) else: # Extract single mesh level m2m_edge_index = m2m_edge_index[0] m2m_features = m2m_features[0] mesh_static_features = mesh_static_features[0] - mesh_up_edge_index, mesh_down_edge_index, mesh_up_features, mesh_down_features =\ - [], [], [], [] + ( + mesh_up_edge_index, + mesh_down_edge_index, + mesh_up_features, + mesh_down_features, + ) = ([], [], [], []) return hierarchical, { "g2m_edge_index": g2m_edge_index, @@ -182,7 +233,7 @@ def make_mlp(blueprint, layer_norm=True): hidden layers of dimensions: blueprint[1], ..., blueprint[-2] if layer_norm is True, includes a LayerNorm layer at - the output (as used iwn GraphCast) + the output (as used in GraphCast) """ hidden_layers = len(blueprint) - 2 assert hidden_layers >= 0, "Invalid MLP blueprint" @@ -202,12 +253,16 @@ def make_mlp(blueprint, layer_norm=True): def fractional_plot_bundle(fraction): """ - Get the tueplots bundle, but with figure width as a fraction of the page width. + Get the tueplots bundle, but with figure width as a fraction of + the page width. """ bundle = bundles.neurips2023(usetex=True, family="serif") bundle.update(figsizes.neurips2023()) original_figsize = bundle["figure.figsize"] - bundle["figure.figsize"] = (original_figsize[0] / fraction, original_figsize[1]) + bundle["figure.figsize"] = ( + original_figsize[0] / fraction, + original_figsize[1], + ) return bundle @@ -218,5 +273,5 @@ def init_wandb_metrics(wandb_logger): """ experiment = wandb_logger.experiment experiment.define_metric("val_mean_loss", summary="min") - for step in constants.val_step_log_errors: - experiment.define_metric(f"val_loss_unroll{step:02}", summary="min") + for step in constants.VAL_STEP_LOG_ERRORS: + experiment.define_metric(f"val_loss_unroll{step}", summary="min") diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 3d547aeb..6b3e4152 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -1,38 +1,52 @@ +# Third-party import cartopy.feature as cf import matplotlib import matplotlib.pyplot as plt import numpy as np import xarray as xr +# First-party from neural_lam import constants, utils from neural_lam.rotate_grid import unrotate_latlon @matplotlib.rc_context(utils.fractional_plot_bundle(1)) -def plot_error_map(errors, global_mean, title=None, step_length=1): +def plot_error_map(errors, global_mean, step_length=1, title=None): """ - Plot a heatmap of errors of different variables at different predictions horizons + Plot a heatmap of errors of different variables at different + predictions horizons errors: (pred_steps, d_f) """ errors_np = errors.T.cpu().numpy() # (d_f, pred_steps) d_f, pred_steps = errors_np.shape - rel_errors = errors_np / np.abs(np.expand_dims(global_mean.cpu(), axis=1)) - height = int(np.sqrt(len(constants.vertical_levels) - * len(constants.param_names_short)) * 2) + errors_norm = errors_np / np.abs(np.expand_dims(global_mean.cpu(), axis=1)) + height = int( + np.sqrt( + len(constants.VERTICAL_LEVELS) * len(constants.PARAM_NAMES_SHORT) + ) + * 2 + ) fig, ax = plt.subplots(figsize=(15, height)) - ax.imshow(rel_errors, cmap="OrRd", vmin=0, vmax=1., interpolation="none", - aspect="auto", alpha=0.8) + ax.imshow( + errors_norm, + cmap="OrRd", + vmin=0, + vmax=1.0, + interpolation="none", + aspect="auto", + alpha=0.8, + ) # ax and labels for (j, i), error in np.ndenumerate(errors_np): # Numbers > 9999 will be too large to fit formatted_error = f"{error:.3f}" if error < 9999 else f"{error:.2E}" - ax.text(i, j, formatted_error, ha='center', va='center', usetex=False) + ax.text(i, j, formatted_error, ha="center", va="center", usetex=False) # Ticks and labels - label_size = 12 + label_size = 15 ax.set_xticks(np.arange(pred_steps)) pred_hor_i = np.arange(pred_steps) + 1 # Prediction horiz. in index pred_hor_h = step_length * pred_hor_i # Prediction horiz. in hours @@ -41,9 +55,17 @@ def plot_error_map(errors, global_mean, title=None, step_length=1): ax.set_yticks(np.arange(d_f)) y_ticklabels = [ - f"{name if name != 'RELHUM' else 'RH'} ({unit}) {f'{level:02}' if constants.is_3d[name] else ''}" - for name, unit in zip(constants.param_names_short, constants.param_units) - for level in (constants.vertical_levels if constants.is_3d[name] else [0])] + ( + f"{name if name != 'RELHUM' else 'RH'} ({unit}) " + f"{f'{level:02}' if constants.IS_3D[name] else ''}" + ) + for name, unit in zip( + constants.PARAM_NAMES_SHORT, constants.PARAM_UNITS + ) + for level in ( + constants.VERTICAL_LEVELS if constants.IS_3D[name] else [0] + ) + ] y_ticklabels = sorted(y_ticklabels) ax.set_yticklabels(y_ticklabels, rotation=30, size=label_size) @@ -54,7 +76,7 @@ def plot_error_map(errors, global_mean, title=None, step_length=1): @matplotlib.rc_context(utils.fractional_plot_bundle(1)) -def plot_prediction(pred, target, obs_mask, title=None, vrange=None): +def plot_prediction(pred, target, title=None, vrange=None): """ Plot example prediction and grond truth. Each has shape (N_grid,) @@ -67,32 +89,35 @@ def plot_prediction(pred, target, obs_mask, title=None, vrange=None): vmin, vmax = vrange[0].cpu().item(), vrange[1].cpu().item() # get test data - data_latlon = xr.open_zarr(constants.example_file).isel(time=0) + data_latlon = xr.open_zarr(constants.EXAMPLE_FILE).isel(time=0) lon, lat = unrotate_latlon(data_latlon) - fig, axes = plt.subplots(2, 1, figsize=constants.fig_size, - subplot_kw={"projection": constants.selected_proj}) + fig, axes = plt.subplots( + 2, + 1, + figsize=constants.FIG_SIZE, + subplot_kw={"projection": constants.SELECTED_PROJ}, + ) # Plot pred and target for ax, data in zip(axes, (target, pred)): - data_grid = data.reshape(*constants.grid_shape[::-1]).cpu().numpy() + data_grid = data.reshape(*constants.GRID_SHAPE[::-1]).cpu().numpy() contour_set = ax.contourf( lon, lat, data_grid, - transform=constants.selected_proj, + transform=constants.SELECTED_PROJ, cmap="plasma", - levels=np.linspace( - vmin, - vmax, - num=100)) - ax.add_feature(cf.BORDERS, linestyle='-', edgecolor='black') - ax.add_feature(cf.COASTLINE, linestyle='-', edgecolor='black') + levels=np.linspace(vmin, vmax, num=100), + ) + ax.add_feature(cf.BORDERS, linestyle="-", edgecolor="black") + ax.add_feature(cf.COASTLINE, linestyle="-", edgecolor="black") ax.gridlines( - crs=constants.selected_proj, + crs=constants.SELECTED_PROJ, draw_labels=False, linewidth=0.5, - alpha=0.5) + alpha=0.5, + ) # Ticks and labels axes[0].set_title("Ground Truth", size=15) @@ -107,7 +132,7 @@ def plot_prediction(pred, target, obs_mask, title=None, vrange=None): @matplotlib.rc_context(utils.fractional_plot_bundle(1)) -def plot_spatial_error(error, obs_mask, title=None, vrange=None): +def plot_spatial_error(error, title=None, vrange=None): """ Plot errors over spatial map Error and obs_mask has shape (N_grid,) @@ -120,31 +145,29 @@ def plot_spatial_error(error, obs_mask, title=None, vrange=None): vmin, vmax = vrange[0].cpu().item(), vrange[1].cpu().item() # get test data - data_latlon = xr.open_zarr(constants.example_file).isel(time=0) + data_latlon = xr.open_zarr(constants.EXAMPLE_FILE).isel(time=0) lon, lat = unrotate_latlon(data_latlon) - fig, ax = plt.subplots(figsize=constants.fig_size, - subplot_kw={"projection": constants.selected_proj}) + fig, ax = plt.subplots( + figsize=constants.FIG_SIZE, + subplot_kw={"projection": constants.SELECTED_PROJ}, + ) - error_grid = error.reshape(*constants.grid_shape[::-1]).cpu().numpy() + error_grid = error.reshape(*constants.GRID_SHAPE[::-1]).cpu().numpy() contour_set = ax.contourf( lon, lat, error_grid, - transform=constants.selected_proj, + transform=constants.SELECTED_PROJ, cmap="OrRd", - levels=np.linspace( - vmin, - vmax, - num=100)) - ax.add_feature(cf.BORDERS, linestyle='-', edgecolor='black') - ax.add_feature(cf.COASTLINE, linestyle='-', edgecolor='black') + levels=np.linspace(vmin, vmax, num=100), + ) + ax.add_feature(cf.BORDERS, linestyle="-", edgecolor="black") + ax.add_feature(cf.COASTLINE, linestyle="-", edgecolor="black") ax.gridlines( - crs=constants.selected_proj, - draw_labels=False, - linewidth=0.5, - alpha=0.5) + crs=constants.SELECTED_PROJ, draw_labels=False, linewidth=0.5, alpha=0.5 + ) # Ticks and labels cbar = fig.colorbar(contour_set, orientation="horizontal", aspect=20) diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py index 52d49a6b..17a00f64 100644 --- a/neural_lam/weather_dataset.py +++ b/neural_lam/weather_dataset.py @@ -1,136 +1,193 @@ +# Standard library import glob import os from datetime import datetime, timedelta +# Third-party import pytorch_lightning as pl import torch import xarray as xr +# First-party # BUG: Import should work in interactive mode as well -> create pypi package from neural_lam import constants, utils +# pylint: disable=W0613:unused-argument +# pylint: disable=W0201:attribute-defined-outside-init + class WeatherDataset(torch.utils.data.Dataset): """ - For our dataset: N_t = 1h N_x = 582 N_y = 390 N_grid = 582*390 = 226980 d_features = 4(features) * 21(vertical model levels) = 84 - d_forcing = 0 #TODO: extract incoming radiation from KENDA + d_forcing = 0 + #TODO: extract incoming radiation from KENDA """ - def __init__(self, dataset_name, split="train", - standardize=True, subset=False, batch_size=4): + def __init__( + self, + dataset_name, + split="train", + standardize=True, + subset=False, + batch_size=4, + control_only=False, + ): super().__init__() assert split in ("train", "val", "test"), "Unknown dataset split" - sample_dir_path = os.path.join("data", dataset_name, "samples", split) + self.sample_dir_path = os.path.join( + "data", dataset_name, "samples", split + ) self.batch_size = batch_size self.batch_index = 0 self.index_within_batch = 0 - self.zarr_files = sorted(glob.glob( - os.path.join(sample_dir_path, "data*.zarr"))) + self.zarr_files = sorted( + glob.glob(os.path.join(self.sample_dir_path, "data*.zarr")) + ) if len(self.zarr_files) == 0: raise ValueError("No .zarr files found in directory") if subset: - if constants.eval_datetime is not None and split == "test": + if constants.EVAL_DATETIME is not None and split == "test": eval_datetime_obj = datetime.strptime( - constants.eval_datetime, "%Y%m%d%H") + constants.EVAL_DATETIME, "%Y%m%d%H" + ) for i, file in enumerate(self.zarr_files): file_datetime_str = file.split("/")[-1].split("_")[1][:-5] - file_datetime_obj = datetime.strptime(file_datetime_str, "%Y%m%d%H") - if file_datetime_obj <= eval_datetime_obj < file_datetime_obj + \ - timedelta(hours=constants.chunk_size): - # Retrieve the current file and the next file if it exists + file_datetime_obj = datetime.strptime( + file_datetime_str, "%Y%m%d%H" + ) + if ( + file_datetime_obj + <= eval_datetime_obj + < file_datetime_obj + + timedelta(hours=constants.CHUNK_SIZE) + ): + # Retrieve the current file and the next file if it + # exists next_file_index = i + 1 if next_file_index < len(self.zarr_files): - self.zarr_files = [file, self.zarr_files[next_file_index]] + self.zarr_files = [ + file, + self.zarr_files[next_file_index], + ] else: self.zarr_files = [file] position_within_file = int( - (eval_datetime_obj - file_datetime_obj).total_seconds() // 3600) - self.batch_index = position_within_file // self.batch_size - self.index_within_batch = position_within_file % self.batch_size + ( + eval_datetime_obj - file_datetime_obj + ).total_seconds() + // 3600 + ) + self.batch_index = ( + position_within_file // self.batch_size + ) + self.index_within_batch = ( + position_within_file % self.batch_size + ) break else: self.zarr_files = self.zarr_files[0:2] - start_datetime = self.zarr_files[0].split( - "/")[-1].split("_")[1].replace('.zarr', '') + start_datetime = ( + self.zarr_files[0] + .split("/")[-1] + .split("_")[1] + .replace(".zarr", "") + ) print("Data subset of 200 samples starts on the", start_datetime) # Separate 3D and 2D variables - variables_3d = [var for var in constants.param_names_short - if constants.is_3d[var]] - variables_2d = [var for var in constants.param_names_short - if not constants.is_3d[var]] + variables_3d = [ + var for var in constants.PARAM_NAMES_SHORT if constants.IS_3D[var] + ] + variables_2d = [ + var + for var in constants.PARAM_NAMES_SHORT + if not constants.IS_3D[var] + ] # Stack 3D variables datasets_3d = [ - xr.open_zarr( - file, - consolidated=True)[variables_3d].sel( - z_1=constants.vertical_levels).to_array().stack( - var=( - 'variable', - 'z_1')).transpose( - "time", - "x_1", - "y_1", - "var") for file in self.zarr_files] + xr.open_zarr(file, consolidated=True)[variables_3d] + .sel(z_1=constants.VERTICAL_LEVELS) + .to_array() + .stack(var=("variable", "z_1")) + .transpose("time", "x_1", "y_1", "var") + for file in self.zarr_files + ] # Stack 2D variables without selecting along z_1 datasets_2d = [ - xr.open_zarr( - file, - consolidated=True)[variables_2d].to_array().expand_dims( - z_1=[0]).stack( - var=( - 'variable', - 'z_1')).transpose( - "time", - "x_1", - "y_1", - "var") for file in self.zarr_files] + xr.open_zarr(file, consolidated=True)[variables_2d] + .to_array() + .expand_dims(z_1=[0]) + .stack(var=("variable", "z_1")) + .transpose("time", "x_1", "y_1", "var") + for file in self.zarr_files + ] # Combine 3D and 2D datasets - self.zarr_datasets = [xr.concat([ds_3d, ds_2d], dim='var').sortby("var") - for ds_3d, ds_2d in zip(datasets_3d, datasets_2d)] + self.zarr_datasets = [ + xr.concat([ds_3d, ds_2d], dim="var").sortby("var") + for ds_3d, ds_2d in zip(datasets_3d, datasets_2d) + ] self.standardize = standardize if standardize: ds_stats = utils.load_dataset_stats(dataset_name, "cpu") - self.data_mean, self.data_std = ds_stats["data_mean"], ds_stats["data_std"] + if constants.GRID_FORCING_DIM > 0: + self.data_mean, self.data_std, self.flux_mean, self.flux_std = ( + ds_stats["data_mean"], + ds_stats["data_std"], + ds_stats["flux_mean"], + ds_stats["flux_std"], + ) + else: + self.data_mean, self.data_std = ( + ds_stats["data_mean"], + ds_stats["data_std"], + ) self.random_subsample = split == "train" self.split = split def __len__(self): - num_steps = constants.train_horizon if self.split == "train" else constants.eval_horizon - total_time = len( - self.zarr_files) * constants.chunk_size - num_steps + num_steps = ( + constants.TRAIN_HORIZON + if self.split == "train" + else constants.EVAL_HORIZON + ) + total_time = len(self.zarr_files) * constants.CHUNK_SIZE - num_steps return total_time def __getitem__(self, idx): - num_steps = constants.train_horizon if self.split == "train" else constants.eval_horizon + num_steps = ( + constants.TRAIN_HORIZON + if self.split == "train" + else constants.EVAL_HORIZON + ) # Calculate which zarr files need to be loaded - start_file_idx = idx // constants.chunk_size - end_file_idx = (idx + num_steps) // constants.chunk_size + start_file_idx = idx // constants.CHUNK_SIZE + end_file_idx = (idx + num_steps) // constants.CHUNK_SIZE # Index of current slice - idx_sample = idx % constants.chunk_size + idx_sample = idx % constants.CHUNK_SIZE sample_archive = xr.concat( - self.zarr_datasets[start_file_idx: end_file_idx + 1], - dim='time') + self.zarr_datasets[start_file_idx : end_file_idx + 1], dim="time" + ) - sample_xr = sample_archive.isel(time=slice(idx_sample, idx_sample + num_steps)) + sample_xr = sample_archive.isel( + time=slice(idx_sample, idx_sample + num_steps) + ) # (N_t', N_x, N_y, d_features') sample = torch.tensor(sample_xr.values, dtype=torch.float32) @@ -138,6 +195,7 @@ def __getitem__(self, idx): sample = sample.flatten(1, 2) # (N_t, N_grid, d_features) if self.standardize: + # Standardize sample sample = (sample - self.data_mean) / self.data_std # Split up sample in init. states and target states @@ -148,8 +206,17 @@ def __getitem__(self, idx): class WeatherDataModule(pl.LightningDataModule): - def __init__(self, dataset_name, split="train", standardize=True, - subset=False, batch_size=4, num_workers=16): + """DataModule for weather data.""" + + def __init__( + self, + dataset_name, + split="train", + standardize=True, + subset=False, + batch_size=4, + num_workers=16, + ): super().__init__() self.dataset_name = dataset_name self.batch_size = batch_size @@ -158,46 +225,60 @@ def __init__(self, dataset_name, split="train", standardize=True, self.subset = subset def prepare_data(self): - # download, split, etc... - # called only on 1 GPU/TPU in distributed + # download, split, etc... called only on 1 GPU/TPU in distributed pass def setup(self, stage=None): - # make assignments here (val/train/test split) - # called on every process in DDP - if stage == 'fit' or stage is None: + # make assignments here (val/train/test split) called on every process + # in DDP + if stage == "fit" or stage is None: self.train_dataset = WeatherDataset( self.dataset_name, split="train", standardize=self.standardize, subset=self.subset, - batch_size=self.batch_size) + batch_size=self.batch_size, + ) self.val_dataset = WeatherDataset( self.dataset_name, split="val", standardize=self.standardize, subset=self.subset, - batch_size=self.batch_size) + batch_size=self.batch_size, + ) - if stage == 'test' or stage is None: + if stage == "test" or stage is None: self.test_dataset = WeatherDataset( self.dataset_name, split="test", standardize=self.standardize, subset=self.subset, - batch_size=self.batch_size) + batch_size=self.batch_size, + ) def train_dataloader(self): return torch.utils.data.DataLoader( - self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, - shuffle=False, pin_memory=False,) + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + pin_memory=False, + ) def val_dataloader(self): return torch.utils.data.DataLoader( - self.val_dataset, batch_size=self.batch_size // self.batch_size, - num_workers=self.num_workers, shuffle=False, pin_memory=False,) + self.val_dataset, + batch_size=self.batch_size // self.batch_size, + num_workers=self.num_workers, + shuffle=False, + pin_memory=False, + ) def test_dataloader(self): return torch.utils.data.DataLoader( - self.test_dataset, batch_size=self.batch_size, - num_workers=self.num_workers, shuffle=False, pin_memory=False) + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + shuffle=False, + pin_memory=False, + ) diff --git a/plot_graph.py b/plot_graph.py index 0c3e3156..48427d5c 100644 --- a/plot_graph.py +++ b/plot_graph.py @@ -1,9 +1,12 @@ +# Standard library from argparse import ArgumentParser +# Third-party import numpy as np import plotly.graph_objects as go import torch_geometric as pyg +# First-party from neural_lam import utils MESH_HEIGHT = 0.1 @@ -12,36 +15,64 @@ def main(): - parser = ArgumentParser(description='Plot graph') + """ + Plot graph structure in 3D using plotly + """ + parser = ArgumentParser(description="Plot graph") parser.add_argument( - '--dataset', type=str, default="meps_example", - help='Datast to load grid coordinates from (default: meps_example)') - parser.add_argument('--graph', type=str, default="multiscale", - help='Graph to plot (default: multiscale)') + "--dataset", + type=str, + default="meps_example", + help="Datast to load grid coordinates from (default: meps_example)", + ) parser.add_argument( - '--save', type=str, - help='Name of .html file to save interactive plot to (default: None)') - parser.add_argument('--show_axis', type=int, default=0, - help='If the axis should be displayed (default: 0 (No))') + "--graph", + type=str, + default="multiscale", + help="Graph to plot (default: multiscale)", + ) + parser.add_argument( + "--save", + type=str, + help="Name of .html file to save interactive plot to (default: None)", + ) + parser.add_argument( + "--show_axis", + type=int, + default=0, + help="If the axis should be displayed (default: 0 (No))", + ) args = parser.parse_args() # Load graph data hierarchical, graph_ldict = utils.load_graph(args.graph) - g2m_edge_index, m2g_edge_index, m2m_edge_index, =\ - graph_ldict["g2m_edge_index"], graph_ldict["m2g_edge_index"], \ - graph_ldict["m2m_edge_index"] - mesh_up_edge_index, mesh_down_edge_index = graph_ldict["mesh_up_edge_index"], \ - graph_ldict["mesh_down_edge_index"] + ( + g2m_edge_index, + m2g_edge_index, + m2m_edge_index, + ) = ( + graph_ldict["g2m_edge_index"], + graph_ldict["m2g_edge_index"], + graph_ldict["m2m_edge_index"], + ) + mesh_up_edge_index, mesh_down_edge_index = ( + graph_ldict["mesh_up_edge_index"], + graph_ldict["mesh_down_edge_index"], + ) mesh_static_features = graph_ldict["mesh_static_features"] - grid_static_features = utils.load_static_data(args.dataset)["grid_static_features"] + grid_static_features = utils.load_static_data(args.dataset)[ + "grid_static_features" + ] # Extract values needed, turn to numpy grid_pos = grid_static_features[:, :2].numpy() # Add in z-dimension z_grid = GRID_HEIGHT * np.ones((grid_pos.shape[0],)) - grid_pos = np.concatenate((grid_pos, np.expand_dims(z_grid, axis=1)), axis=1) + grid_pos = np.concatenate( + (grid_pos, np.expand_dims(z_grid, axis=1)), axis=1 + ) # List of edges to plot, (edge_index, color, line_width, label) edge_plot_list = [ @@ -49,26 +80,39 @@ def main(): (g2m_edge_index.numpy(), "black", 0.4, "G2M"), ] - # Mesh positioning and edges to plot differ if we have a hierachical graph + # Mesh positioning and edges to plot differ if we have a hierarchical graph if hierarchical: - mesh_level_pos = [np.concatenate(( - level_static_features.numpy(), - MESH_HEIGHT + MESH_LEVEL_DIST * height_level * np.ones( - (level_static_features.shape[0], 1)), - ), axis=1) - for height_level, level_static_features - in enumerate(mesh_static_features, start=1)] + mesh_level_pos = [ + np.concatenate( + ( + level_static_features.numpy(), + MESH_HEIGHT + + MESH_LEVEL_DIST + * height_level + * np.ones((level_static_features.shape[0], 1)), + ), + axis=1, + ) + for height_level, level_static_features in enumerate( + mesh_static_features, start=1 + ) + ] mesh_pos = np.concatenate(mesh_level_pos, axis=0) # Add inter-level mesh edges - edge_plot_list += [(level_ei.numpy(), "blue", 1, f"M2M Level {level}") - for level, level_ei in enumerate(m2m_edge_index)] + edge_plot_list += [ + (level_ei.numpy(), "blue", 1, f"M2M Level {level}") + for level, level_ei in enumerate(m2m_edge_index) + ] # Add intra-level mesh edges - up_edges_ei = np.concatenate([level_up_ei.numpy() - for level_up_ei in mesh_up_edge_index], axis=1) - down_edges_ei = np.concatenate([level_down_ei.numpy() - for level_down_ei in mesh_down_edge_index], axis=1) + up_edges_ei = np.concatenate( + [level_up_ei.numpy() for level_up_ei in mesh_up_edge_index], axis=1 + ) + down_edges_ei = np.concatenate( + [level_down_ei.numpy() for level_down_ei in mesh_down_edge_index], + axis=1, + ) edge_plot_list.append((up_edges_ei, "green", 1, "Mesh up")) edge_plot_list.append((down_edges_ei, "green", 1, "Mesh down")) @@ -80,7 +124,9 @@ def main(): z_mesh = MESH_HEIGHT + 0.01 * mesh_degrees mesh_node_size = mesh_degrees / 2 - mesh_pos = np.concatenate((mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1) + mesh_pos = np.concatenate( + (mesh_pos, np.expand_dims(z_mesh, axis=1)), axis=1 + ) edge_plot_list.append((m2m_edge_index.numpy(), "blue", 1, "M2M")) @@ -89,21 +135,34 @@ def main(): # Add edges data_objs = [] - for ei, col, width, label, in edge_plot_list: + for ( + ei, + col, + width, + label, + ) in edge_plot_list: edge_start = node_pos[ei[0]] # (M, 2) edge_end = node_pos[ei[1]] # (M, 2) n_edges = edge_start.shape[0] - x_edges = np.stack((edge_start[:, 0], edge_end[:, 0], np.full(n_edges, None)), - axis=1).flatten() - y_edges = np.stack((edge_start[:, 1], edge_end[:, 1], np.full(n_edges, None)), - axis=1).flatten() - z_edges = np.stack((edge_start[:, 2], edge_end[:, 2], np.full(n_edges, None)), - axis=1).flatten() + x_edges = np.stack( + (edge_start[:, 0], edge_end[:, 0], np.full(n_edges, None)), axis=1 + ).flatten() + y_edges = np.stack( + (edge_start[:, 1], edge_end[:, 1], np.full(n_edges, None)), axis=1 + ).flatten() + z_edges = np.stack( + (edge_start[:, 2], edge_end[:, 2], np.full(n_edges, None)), axis=1 + ).flatten() scatter_obj = go.Scatter3d( - x=x_edges, y=y_edges, z=z_edges, mode='lines', line={ - "color": col, "width": width}, name=label) + x=x_edges, + y=y_edges, + z=z_edges, + mode="lines", + line={"color": col, "width": width}, + name=label, + ) data_objs.append(scatter_obj) # Add node objects @@ -113,29 +172,35 @@ def main(): x=grid_pos[:, 0], y=grid_pos[:, 1], z=grid_pos[:, 2], - mode='markers', marker={"color": "black", "size": 1}, - name="Grid nodes")) + mode="markers", + marker={"color": "black", "size": 1}, + name="Grid nodes", + ) + ) data_objs.append( go.Scatter3d( x=mesh_pos[:, 0], y=mesh_pos[:, 1], z=mesh_pos[:, 2], - mode='markers', marker={"color": "blue", "size": mesh_node_size}, - name="Mesh nodes")) + mode="markers", + marker={"color": "blue", "size": mesh_node_size}, + name="Mesh nodes", + ) + ) fig = go.Figure(data=data_objs) - fig.update_layout(scene_aspectmode='data') + fig.update_layout(scene_aspectmode="data") fig.update_traces(connectgaps=False) if not args.show_axis: # Hide axis fig.update_layout( - scene=dict( - xaxis=dict(visible=False), - yaxis=dict(visible=False), - zaxis=dict(visible=False) - ) + scene={ + "xaxis": {"visible": False}, + "yaxis": {"visible": False}, + "zaxis": {"visible": False}, + } ) if args.save: diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..55c07c25 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,66 @@ +[tool.black] +line-length = 80 + +[tool.isort] +default_section = "THIRDPARTY" +profile = "black" +# Headings +import_heading_stdlib = "Standard library" +import_heading_thirdparty = "Third-party" +import_heading_firstparty = "First-party" +import_heading_localfolder = "Local" +# Known modules to avoid misclassification +known_standard_library = [ + # Add standard library modules that may be misclassified by isort +] +known_third_party = [ + # Add third-party modules that may be misclassified by isort + "wandb", +] +known_first_party = [ + # Add first-party modules that may be misclassified by isort + "neural_lam", +] + +[tool.flake8] +max-line-length = 80 +ignore = [ + "E203", # Allow whitespace before ':' (https://github.com/PyCQA/pycodestyle/issues/373) + "I002", # Don't check for isort configuration + "W503", # Allow line break before binary operator (PEP 8-compatible) +] +per-file-ignores = [ + "__init__.py: F401", # Allow unused imports +] + +[tool.codespell] +skip = "requirements/*" + +# Pylint config +[tool.pylint] +ignore = [ + "create_mesh.py", # Disable linting for now, as major rework is planned/expected +] +# Temporary fix for import neural_lam statements until set up as proper package +init-hook='import sys; sys.path.append(".")' +[tool.pylint.TYPECHECK] +generated-members = [ + "numpy.*", + "torch.*", +] +[tool.pylint.'MESSAGES CONTROL'] +disable = [ + "C0114", # 'missing-module-docstring', Do not require module docstrings + "R0901", # 'too-many-ancestors', Allow many layers of sub-classing + "R0902", # 'too-many-instance-attribtes', Allow many attributes + "R0913", # 'too-many-arguments', Allow many function arguments + "R0914", # 'too-many-locals', Allow many local variables + "W0223", # 'abstract-method', Subclasses do not have to override all abstract methods +] +[tool.pylint.DESIGN] +max-statements=100 # Allow for some more involved functions +[tool.pylint.IMPORTS] +allow-any-import-level="neural_lam" +known-third-party="wandb" +[tool.pylint.SIMILARITIES] +min-similarity-lines=10 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..f99002c2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,22 @@ +# for all +numpy>=1.24.2 +wandb>=0.13.10 +matplotlib>=3.7.0 +scipy>=1.10.0 +pytorch-lightning>=2.0.3 +shapely>=2.0.1 +networkx>=3.0 +imageio>=2.34.0 +numcodecs>=0.12.1 +Cartopy>=0.22.0 +pyproj>=3.4.1 +xarray>=2024.1.1 +tueplots>=0.0.8 +plotly>=5.15.0 +# for dev +codespell>=2.0.0 +black>=21.9b0 +isort>=5.9.3 +flake8>=4.0.1 +pylint>=3.0.3 +pre-commit>=2.15.0 diff --git a/slurm_eval.sh b/slurm_eval.sh index 3553d20f..3995292c 100644 --- a/slurm_eval.sh +++ b/slurm_eval.sh @@ -1,13 +1,13 @@ #!/bin/bash -l #SBATCH --job-name=NeurWPe #SBATCH --nodes=1 -#SBATCH --gpus-per-node=1 -#SBATCH --ntasks-per-node=1 +#SBATCH --ntasks-per-node=4 #SBATCH --partition=a100-80gb #SBATCH --account=s83 #SBATCH --output=lightning_logs/neurwp_eval_out.log #SBATCH --error=lightning_logs/neurwp_eval_err.log #SBATCH --time=03:00:00 +#SBATCH --no-requeue export PREPROCESS=true export NORMALIZE=false @@ -16,11 +16,15 @@ export NORMALIZE=false conda activate neural-lam if [ "$PREPROCESS" = true ]; then + echo "Create static features" srun -ul -N1 -n1 python create_static_features.py --boundaries 60 + echo "Creating mesh" srun -ul -N1 -n1 python create_mesh.py --dataset "cosmo" --plot 1 + echo "Creating grid features" srun -ul -N1 -n1 python create_grid_features.py --dataset "cosmo" if [ "$NORMALIZE" = true ]; then # This takes multiple hours! + echo "Creating normalization weights" srun -ul -N1 -n1 python create_parameter_weights.py --dataset "cosmo" --batch_size 32 --n_workers 8 --step_length 1 fi fi @@ -28,6 +32,5 @@ fi ulimit -c 0 export OMP_NUM_THREADS=16 -# Run the script with torchrun srun -ul python train_model.py --load "wandb/example.ckpt" --dataset "cosmo" \ - --eval="test" --subset_ds 1 --n_workers 2 --batch_size 6 --wandb_mode "offline" + --eval="test" --subset_ds 1 --n_workers 2 --batch_size 6 diff --git a/slurm_train.sh b/slurm_train.sh index 1c16304c..1930e2b1 100644 --- a/slurm_train.sh +++ b/slurm_train.sh @@ -1,7 +1,6 @@ #!/bin/bash -l #SBATCH --job-name=NeurWP #SBATCH --nodes=1 -#SBATCH --gpus-per-node=4 #SBATCH --ntasks-per-node=4 #SBATCH --partition=a100-80gb #SBATCH --account=s83 @@ -34,5 +33,5 @@ ulimit -c 0 export OMP_NUM_THREADS=16 # Run the script with torchrun -srun -ul --gpus-per-task=1 python train_model.py --dataset "cosmo" --val_interval 5 \ - --epochs 10 --n_workers 6 --batch_size 8 --subset_ds 1 --wandb_mode "offline" +srun -ul python train_model.py --dataset "cosmo" --val_interval 5 \ + --epochs 10 --n_workers 6 --batch_size 8 --subset_ds 1 diff --git a/train_model.py b/train_model.py index 64fb98e7..264c536d 100644 --- a/train_model.py +++ b/train_model.py @@ -1,14 +1,16 @@ +# Standard library import os -import resource import time from argparse import ArgumentParser +# Third-party import pytorch_lightning as pl import torch +import wandb from lightning_fabric.utilities import seed from pytorch_lightning.utilities import rank_zero_only -import wandb +# First-party from neural_lam import constants, utils from neural_lam.models.graph_lam import GraphLAM from neural_lam.models.hi_lam import HiLAM @@ -24,6 +26,7 @@ @rank_zero_only def print_args(args): + """Print arguments""" print("Arguments:") for arg in vars(args): print(f"{arg}: {getattr(args, arg)}") @@ -31,127 +34,226 @@ def print_args(args): @rank_zero_only def print_eval(args_eval): + """Print evaluation""" print(f"Running evaluation on {args_eval}") @rank_zero_only def init_wandb(args): + """Initialize wandb""" if args.resume_run is None: prefix = "subset-" if args.subset_ds else "" if args.eval: prefix = prefix + f"eval-{args.eval}-" - run_name = f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-"\ + run_name = ( + f"{prefix}{args.model}-{args.processor_layers}x{args.hidden_dim}-" f"{time.strftime('%m_%d_%H_%M_%S')}" + ) wandb.init( name=run_name, - project=constants.wandb_project, + project=constants.WANDB_PROJECT, config=args, - mode=args.wandb_mode ) logger = pl.loggers.WandbLogger( - project=constants.wandb_project, + project=constants.WANDB_PROJECT, name=run_name, config=args, - log_model=True) + log_model=True, + ) wandb.save("slurm_train.sh") - wandb.save("neural_lam/constants.py") + wandb.save("neural_lam/constants.PY") else: wandb.init( - project=constants.wandb_project, + project=constants.WANDB_PROJECT, config=args, id=args.resume_run, - resume='must', - mode=args.wandb_mode + resume="must", ) logger = pl.loggers.WandbLogger( - project=constants.wandb_project, + project=constants.WANDB_PROJECT, id=args.resume_run, config=args, - log_model=True) + log_model=True, + ) return logger def main(): - # if torch.cuda.is_available(): - # init_process_group(backend="nccl") - parser = ArgumentParser(description='Train or evaluate NeurWP models for LAM') + # pylint: disable=too-many-branches + """ + Main function for training and evaluating models + """ + parser = ArgumentParser( + description="Train or evaluate NeurWP models for LAM" + ) # General options parser.add_argument( - '--dataset', type=str, default="meps_example", - help='Dataset, corresponding to name in data directory (default: meps_example)') - parser.add_argument( - '--model', type=str, default="graph_lam", - help='Model architecture to train/evaluate (default: graph_lam)') - parser.add_argument( - '--subset_ds', type=int, default=0, - help='Use only a small subset of the dataset, for debugging (default: 0=false)') - parser.add_argument('--seed', type=int, default=42, - help='random seed (default: 42)') - parser.add_argument('--n_workers', type=int, default=4, - help='Number of workers in data loader (default: 4)') - parser.add_argument('--epochs', type=int, default=200, - help='upper epoch limit (default: 200)') - parser.add_argument('--batch_size', type=int, default=4, - help='batch size (default: 4)') - parser.add_argument('--load', type=str, - help='Path to load model parameters from (default: None)') - parser.add_argument('--resume_run', type=str, - help='Run ID to resume (default: None)') - parser.add_argument('--resume_opt_sched', type=int, default=0, - help='Resume optimizer and scheduler state (default: 0=false)') - parser.add_argument( - '--precision', type=str, default=32, - help='Numerical precision to use for model (32/16/bf16) (default: 32)') - parser.add_argument('--wandb_mode', type=str, default="online", - help='Wandb mode (online/offline/dryrun) (default: online)') + "--dataset", + type=str, + default="meps_example", + help="Dataset, corresponding to name in data directory " + "(default: meps_example)", + ) + parser.add_argument( + "--model", + type=str, + default="graph_lam", + help="Model architecture to train/evaluate (default: graph_lam)", + ) + parser.add_argument( + "--subset_ds", + type=int, + default=0, + help="Use only a small subset of the dataset, for debugging" + "(default: 0=false)", + ) + parser.add_argument( + "--seed", type=int, default=42, help="random seed (default: 42)" + ) + parser.add_argument( + "--n_workers", + type=int, + default=4, + help="Number of workers in data loader (default: 4)", + ) + parser.add_argument( + "--epochs", + type=int, + default=200, + help="upper epoch limit (default: 200)", + ) + parser.add_argument( + "--batch_size", type=int, default=4, help="batch size (default: 4)" + ) + parser.add_argument( + "--load", + type=str, + help="Path to load model parameters from (default: None)", + ) + parser.add_argument( + "--resume_run", type=str, help="Run ID to resume (default: None)" + ) + parser.add_argument( + "--restore_opt", + type=int, + default=0, + help="If optimizer state should be restored with model " + "(default: 0 (false))", + ) + parser.add_argument( + "--precision", + type=str, + default=32, + help="Numerical precision to use for model (32/16/bf16) (default: 32)", + ) # Model architecture parser.add_argument( - '--graph', type=str, default="multiscale", - help='Graph to load and use in graph-based model (default: multiscale)') + "--graph", + type=str, + default="multiscale", + help="Graph to load and use in graph-based model " + "(default: multiscale)", + ) parser.add_argument( - '--hidden_dim', type=int, default=64, - help='Dimensionality of all hidden representations (default: 64)') - parser.add_argument('--hidden_layers', type=int, default=1, - help='Number of hidden layers in all MLPs (default: 1)') - parser.add_argument('--processor_layers', type=int, default=4, - help='Number of GNN layers in processor GNN (default: 4)') + "--hidden_dim", + type=int, + default=64, + help="Dimensionality of all hidden representations (default: 64)", + ) parser.add_argument( - '--mesh_aggr', type=str, default="sum", - help='Aggregation to use for m2m processor GNN layers (sum/mean) (default: sum)') + "--hidden_layers", + type=int, + default=1, + help="Number of hidden layers in all MLPs (default: 1)", + ) + parser.add_argument( + "--processor_layers", + type=int, + default=4, + help="Number of GNN layers in processor GNN (default: 4)", + ) + parser.add_argument( + "--mesh_aggr", + type=str, + default="sum", + help="Aggregation to use for m2m processor GNN layers (sum/mean) " + "(default: sum)", + ) + parser.add_argument( + "--output_std", + type=int, + default=0, + help="If models should additionally output std.-dev. per " + "output dimensions " + "(default: 0 (no))", + ) # Training options parser.add_argument( - '--ar_steps', type=int, default=1, - help='Number of steps to unroll prediction for in loss (1-24) (default: 1)') - parser.add_argument('--loss', type=str, default="mse", - help='Loss function to use (default: mse)') + "--ar_steps", + type=int, + default=1, + help="Number of steps to unroll prediction for in loss (1-19) " + "(default: 1)", + ) parser.add_argument( - '--step_length', type=int, default=1, - help='Step length in hours to consider single time step 1-3 (default: 1)') - parser.add_argument('--lr', type=float, default=1e-3, - help='learning rate (default: 0.001)') + "--control_only", + type=int, + default=0, + help="Train only on control member of ensemble data " + "(default: 0 (False))", + ) parser.add_argument( - '--val_interval', type=int, default=1, - help='Number of epochs training between each validation run (default: 1)') + "--loss", + type=str, + default="wmse", + help="Loss function to use, see metric.py (default: wmse)", + ) + parser.add_argument( + "--step_length", + type=int, + default=1, + help="Step length in hours to consider single time step 1-3 " + "(default: 1)", + ) + parser.add_argument( + "--lr", type=float, default=1e-3, help="learning rate (default: 0.001)" + ) + parser.add_argument( + "--val_interval", + type=int, + default=1, + help="Number of epochs training between each validation run " + "(default: 1)", + ) # Evaluation options parser.add_argument( - '--eval', type=str, default=None, - help='Eval model on given data split (val/test) (default: None (train model))') + "--eval", + type=str, + help="Eval model on given data split (val/test) " + "(default: None (train model))", + ) parser.add_argument( - '--n_example_pred', type=int, default=1, - help='Number of example predictions to plot during evaluation (default: 1)') + "--n_example_pred", + type=int, + default=1, + help="Number of example predictions to plot during evaluation " + "(default: 1)", + ) args = parser.parse_args() + # Asserts for arguments assert args.model in MODELS, f"Unknown model: {args.model}" assert args.step_length <= 3, "Too high step length" - assert args.eval in (None, "val", "test"), f"Unknown eval setting: {args.eval}" - assert args.loss in ("mse", "mae", "huber"), f"Unknown loss function: {args.loss}" - - resource.setrlimit(resource.RLIMIT_CORE, (0, resource.RLIM_INFINITY)) + assert args.eval in ( + None, + "val", + "test", + ), f"Unknown eval setting: {args.eval}" # Set seed seed.seed_everything(args.seed) @@ -161,12 +263,14 @@ def main(): args.dataset, subset=bool(args.subset_ds), batch_size=args.batch_size, - num_workers=args.n_workers + num_workers=args.n_workers, ) - # Get the device for the current process + # Instantiate model + trainer if torch.cuda.is_available(): - torch.set_float32_matmul_precision("high") # Allows using Tensor Cores on A100s + torch.set_float32_matmul_precision( + "high" + ) # Allows using Tensor Cores on A100s # Load model parameters Use new args for model model_class = MODELS[args.model] @@ -190,7 +294,6 @@ def main(): save_on_train_epoch_end=True, verbose=True, ) - if args.eval: use_distributed_sampler = False else: @@ -199,7 +302,7 @@ def main(): if torch.cuda.is_available(): accelerator = "cuda" devices = torch.cuda.device_count() - num_nodes = int(os.environ.get('SLURM_JOB_NUM_NODES', 1)) + num_nodes = int(os.environ.get("SLURM_JOB_NUM_NODES", 1)) else: accelerator = "cpu" devices = 1 @@ -209,7 +312,9 @@ def main(): max_epochs=args.epochs, logger=logger, log_every_n_steps=1, - callbacks=[checkpoint_callback] if checkpoint_callback is not None else [], + callbacks=( + [checkpoint_callback] if checkpoint_callback is not None else [] + ), check_val_every_n_epoch=args.val_interval, precision=args.precision, use_distributed_sampler=use_distributed_sampler, @@ -237,12 +342,14 @@ def main(): # Train model data_module.split = "train" if args.load: - trainer.fit(model=model, datamodule=data_module, ckpt_path=args.load) + trainer.fit( + model=model, datamodule=data_module, ckpt_path=args.load + ) else: trainer.fit(model=model, datamodule=data_module) # Print profiler - print(trainer.profiler) + print(trainer.profiler) # pylint: disable=no-member if __name__ == "__main__":