Merge upstream #10

Merged: 28 commits, Mar 7, 2024

Commits
b20c77c  Update readme with link to example data (joeloskarsson, Oct 12, 2023)
9ef74e4  Add MIT license (joeloskarsson, Oct 12, 2023)
89a4c63  Implement multi-GPU training using DDP (joeloskarsson, Oct 13, 2023)
6866989  Update readme with links to ideas for making code area-agnostic (joeloskarsson, Oct 24, 2023)
d24c0a8  Fix bug where some mesh features were set as persistent buffers (joeloskarsson, Nov 1, 2023)
e5f1ad3  Refactor all InteractionNet instances to all use same general class (joeloskarsson, Oct 24, 2023)
2d86715  Add option to train only on control member of ensemble dataset (joeloskarsson, Oct 31, 2023)
6377d44  Fix bug in test rmse computation, causing incorrect values (generally… (joeloskarsson, Nov 6, 2023)
9912ece  Make sure wandb is initialized before defining metrics, also for pyto… (joeloskarsson, Nov 9, 2023)
cd94f57  Change run id format to avoid name collisions (joeloskarsson, Nov 25, 2023)
2378ed7  Update bibtex in readme (joeloskarsson, Dec 13, 2023)
c14b6b4  Introduce metrics module with new loss options (joeloskarsson, Jan 9, 2024)
474bad9  Add pre-commit configuration for linting and formatting (#6) (sadamov, Feb 1, 2024)
1cddf09  Fix github pre-commit action using incomplete python env (#8) (joeloskarsson, Feb 1, 2024)
0669ff4  Re-define RMSE metric to take sqrt after sample averaging (#10) (joeloskarsson, Feb 29, 2024)
4539819  Merge remote-tracking branch 'upstream/main' into merge_upstream (sadamov, Mar 1, 2024)
0793684  fixing format (sadamov, Mar 1, 2024)
96f895f  pre-commit and merge complete (sadamov, Mar 2, 2024)
9d592a2  Fix formating and bugs after merge with upstream/main (sadamov, Mar 6, 2024)
000275a  set preprocess=true for simpler first use (sadamov, Mar 6, 2024)
4bbe43c  Install same linters/formatters as used in pre-commit (sadamov, Mar 6, 2024)
ed305ca  Set default step_len to 1 (sadamov, Mar 6, 2024)
daf95ff  Fixed import order with flake8 as it is not consistent with ruff (sadamov, Mar 6, 2024)
de0437e  Merge remote-tracking branch 'upstream/main' into merge_upstream (sadamov, Mar 6, 2024)
c364ef4  Removed .vscode folder (sadamov, Mar 7, 2024)
f4c6ebe  fix import order (twicki, Mar 7, 2024)
ca80852  fix flake8 issues (twicki, Mar 7, 2024)
be9175e  Merge pull request #11 from twicki/fix_upstream (sadamov, Mar 7, 2024)

Files changed
33 changes: 33 additions & 0 deletions .github/workflows/pre-commit.yml
@@ -0,0 +1,33 @@
name: Run pre-commit job

on:
  push:
    branches:
      - main
  pull_request:
    branches:
      - main

jobs:
  pre-commit-job:
    runs-on: ubuntu-latest
    defaults:
      run:
        shell: bash -l {0}
    steps:
      - uses: actions/checkout@v2
      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: 3.9
      - name: Install pre-commit hooks
        run: |
          pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 \
            --index-url https://download.pytorch.org/whl/cpu
          pip install -r requirements.txt
          pip install pyg-lib==0.2.0 torch-scatter==2.1.1 torch-sparse==0.6.17 \
            torch-cluster==1.6.1 torch-geometric==2.3.1 \
            -f https://pytorch-geometric.com/whl/torch-2.0.1+cpu.html
      - name: Run pre-commit hooks
        run: |
          pre-commit run --all-files
1 change: 1 addition & 0 deletions .gitignore
@@ -8,6 +8,7 @@ graphs
sweeps
test_*.sh
lightning_logs
.vscode

### Python ###
# Byte-compiled / optimized / DLL files
51 changes: 51 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,51 @@
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.5.0
    hooks:
      - id: check-ast
      - id: check-case-conflict
      - id: check-docstring-first
      - id: check-symlinks
      - id: check-toml
      - id: check-yaml
      - id: debug-statements
      - id: end-of-file-fixer
      - id: trailing-whitespace
  - repo: local
    hooks:
      - id: codespell
        name: codespell
        description: Check for spelling errors
        language: system
        entry: codespell
  - repo: local
    hooks:
      - id: black
        name: black
        description: Format Python code
        language: system
        entry: black
        types_or: [python, pyi]
  - repo: local
    hooks:
      - id: isort
        name: isort
        description: Group and sort Python imports
        language: system
        entry: isort
        types_or: [python, pyi, cython]
  - repo: local
    hooks:
      - id: flake8
        name: flake8
        description: Check Python code for correctness, consistency and adherence to best practices
        language: system
        entry: flake8 --max-line-length=80 --ignore=E203,F811,I002,W503
        types: [python]
  - repo: local
    hooks:
      - id: pylint
        name: pylint
        entry: pylint -rn -sn
        language: system
        types: [python]
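
Because every tool in this config is declared with `language: system`, pre-commit does not install the linters itself; they must already be available in the active environment (see the "Install same linters/formatters as used in pre-commit" commit above). A minimal sketch of one way to provide them, assuming a pip-based environment and unpinned versions:

```bash
# Install the tools that the local pre-commit hooks call via `entry`
pip install codespell black isort flake8 pylint pre-commit
```
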
22 changes: 0 additions & 22 deletions .vscode/launch.json

This file was deleted.

3 changes: 0 additions & 3 deletions .vscode/settings.json

This file was deleted.

34 changes: 24 additions & 10 deletions README.md
@@ -1,6 +1,7 @@
<p align="middle">
<img src="figures/neural_lam_header.png" width="700">
</p>

Neural-LAM is a repository of graph-based neural weather prediction models for Limited Area Modeling (LAM).
The code uses [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/pytorch-lightning).
Graph Neural Networks are implemented using [PyG](https://pyg.org/) and logging is set up through [Weights & Biases](https://wandb.ai/).
@@ -11,16 +12,18 @@ The repository contains LAM versions of:
* GraphCast, by [Lam et al. (2023)](https://arxiv.org/abs/2212.12794).
* The hierarchical model from [Oskarsson et al. (2023)](https://arxiv.org/abs/2309.17370).

For more information see our preprint: [*Graph-based Neural Weather Prediction for Limited Area Modeling*](https://arxiv.org/abs/2309.17370).
For more information see our paper: [*Graph-based Neural Weather Prediction for Limited Area Modeling*](https://arxiv.org/abs/2309.17370).
If you use Neural-LAM in your work, please cite:
```
@article{oskarsson2023graphbased,
title={Graph-based Neural Weather Prediction for Limited Area Modeling},
author={Joel Oskarsson and Tomas Landelius and Fredrik Lindsten},
year={2023},
journal={arXiv preprint arXiv:2309.17370}
@inproceedings{oskarsson2023graphbased,
title={Graph-based Neural Weather Prediction for Limited Area Modeling},
author={Oskarsson, Joel and Landelius, Tomas and Lindsten, Fredrik},
booktitle={NeurIPS 2023 Workshop on Tackling Climate Change with Machine Learning},
year={2023}
}
```
As the code in the repository is continuously evolving, the latest version might feature some small differences to what was used in the paper.
See the branch [`ccai_paper_2023`](https://github.com/joeloskarsson/neural-lam/tree/ccai_paper_2023) for a revision of the code that reproduces the workshop paper.

We plan to continue updating this repository as we improve existing models and develop new ones.
Collaborations around this implementation are very welcome.
@@ -47,10 +50,11 @@ mamba env create -f environment.yml
mamba activate neural-lam

# Run the preprocessing/training scripts
# (don't execute preprocessing scripts at the same time as training)
sbatch slurm_train.sh

# Run the evaluation script and generate plots and gif for TQV
# (don't execute preprocessing scripts at the same time as training)
# (by default this will use the pre-trained model from `wandb/example.ckpt`)
sbatch slurm_eval.sh

```
@@ -101,9 +105,9 @@ Note that only the cuda version is pinned to 11.8, otherwise all the latest libr
</span>

\
Follow the steps below to create the neccesary python environment.
Follow the steps below to create the necessary python environment.

1. Install GEOS for your system. For example with `sudo apt-get install libgeos-dev`. This is neccesary for the Cartopy requirement.
1. Install GEOS for your system. For example with `sudo apt-get install libgeos-dev`. This is necessary for the Cartopy requirement.
2. Use python 3.9.
3. Install version 2.0.1 of PyTorch. Follow instructions on the [PyTorch webpage](https://pytorch.org/get-started/previous-versions/) for how to set this up with GPU support on your system.
4. Install required packages specified in `requirements.txt`.
@@ -233,7 +237,7 @@ python train_model.py --model hi_lam --graph hierarchical ...
```

### Hi-LAM-Parallel
A version of Hi-LAM where all message passing in the hierarchical mesh (up, down, inter-level) is ran in paralell.
A version of Hi-LAM where all message passing in the hierarchical mesh (up, down, inter-level) is ran in parallel.
Not included in the paper as initial experiments showed worse results than Hi-LAM, but could be interesting to try in more settings.

To train Hi-LAM-Parallel use
@@ -343,6 +347,16 @@ In addition, hierarchical mesh graphs (`L > 1`) feature a few additional files w
These files have the same list format as the ones above, but each list has length `L-1` (as these edges describe connections between levels).
Entries 0 in these lists describe edges between the lowest levels 1 and 2.
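
As a rough illustration of this list-of-edge-sets format (the file name and path below are hypothetical placeholders, not necessarily the exact names used by the repository):

```python
import torch

# Hypothetical example: load one per-level edge index file of a
# hierarchical mesh graph saved with torch.save
up_edge_index = torch.load("graphs/hierarchical/mesh_up_edge_index.pt")

# The list has one entry per pair of adjacent levels, i.e. length L - 1
num_levels = len(up_edge_index) + 1
# Entry 0 describes the edges between the two lowest levels (1 and 2)
edges_levels_1_2 = up_edge_index[0]
print(num_levels, edges_levels_1_2.shape)
```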

# Development and Contributing
Any push or Pull-Request to the main branch will trigger a selection of pre-commit hooks.
These hooks will run a series of checks on the code, like formatting and linting.
If any of these checks fail the push or PR will be rejected.
To test whether your code passes these checks before pushing, run
``` bash
pre-commit run --all-files
```
from the root directory of the repository.
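
To have the same checks run automatically on every `git commit` as well, one possible local setup (a sketch, assuming a pip-based environment) is:

```bash
pip install pre-commit   # if it is not already in your environment
pre-commit install       # registers the git hook in .git/hooks
```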

# Contact
If you are interested in machine learning models for LAM, have questions about our implementation or ideas for extending it, feel free to get in touch.
You can open a github issue on this page, or (if more suitable) send an email to [[email protected]](mailto:[email protected]).
55 changes: 35 additions & 20 deletions create_grid_features.py
@@ -1,47 +1,62 @@
# Standard library
import os
from argparse import ArgumentParser

# Third-party
import numpy as np
import torch


def main():
parser = ArgumentParser(description='Training arguments')
parser.add_argument('--dataset', type=str, default="meps_example",
help='Dataset to compute weights for (default: meps_example)')
"""
Pre-compute all static features related to the grid nodes
"""
parser = ArgumentParser(description="Training arguments")
parser.add_argument(
"--dataset",
type=str,
default="meps_example",
help="Dataset to compute weights for (default: meps_example)",
)
args = parser.parse_args()

static_dir_path = os.path.join("data", args.dataset, "static")

# -- Static grid node features --
grid_xy = torch.tensor(np.load(os.path.join(static_dir_path, "nwp_xy.npy")
)) # (2, N_x, N_y)
grid_xy = torch.tensor(
np.load(os.path.join(static_dir_path, "nwp_xy.npy"))
) # (2, N_x, N_y)
grid_xy = grid_xy.flatten(1, 2).T # (N_grid, 2)
pos_max = torch.max(torch.abs(grid_xy))
grid_xy = grid_xy / pos_max # Divide by maximum coordinate

geopotential = torch.tensor(
np.load(
os.path.join(
static_dir_path,
"reference_geopotential_pressure.npy"))) # (N_x, N_y)
geopotential = geopotential.flatten(0, 1) # (N_grid, N_static)
os.path.join(static_dir_path, "reference_geopotential_pressure.npy")
)
) # (N_x, N_y, N_fields)
geopotential = geopotential.flatten(0, 1) # (N_grid, N_fields)
gp_min = torch.min(geopotential)
gp_max = torch.max(geopotential)
# Rescale geopotential to [0,1]
geopotential = (geopotential - gp_min) / (
gp_max - gp_min
) # (N_grid, N_fields)

grid_border_mask = torch.tensor(
np.load(
os.path.join(
static_dir_path,
"border_mask.npy")),
dtype=torch.int64) # (N_x, N_y)
grid_border_mask = grid_border_mask.flatten(0, 1).to(
torch.float).unsqueeze(1) # (N_grid, 1)
np.load(os.path.join(static_dir_path, "border_mask.npy")),
dtype=torch.int64,
) # (N_x, N_y)
grid_border_mask = (
grid_border_mask.flatten(0, 1).to(torch.float).unsqueeze(1)
) # (N_grid, 1)

# Concatenate grid features
grid_features = torch.cat((grid_xy, geopotential, grid_border_mask),
dim=1) # (N_grid, 3 + N_static)
grid_features = torch.cat(
(grid_xy, geopotential, grid_border_mask), dim=1
) # (N_grid, 4)

torch.save(grid_features, os.path.join(
static_dir_path, "grid_features.pt"))
torch.save(grid_features, os.path.join(static_dir_path, "grid_features.pt"))


if __name__ == "__main__":
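
The reformatted script above concatenates the normalized coordinates, rescaled geopotential and border mask into a single (N_grid, 4) tensor and saves it to `grid_features.pt`. A quick, hypothetical sanity check of that output (the path assumes the default `meps_example` dataset) could look like:

```python
import torch

# Load the static grid features written by create_grid_features.py
grid_features = torch.load("data/meps_example/static/grid_features.pt")

print(grid_features.shape)  # (N_grid, 4): x, y, geopotential, border mask
# Coordinates are divided by the max absolute coordinate, geopotential is
# min-max scaled to [0, 1] and the border mask is 0/1
print(grid_features.min(dim=0).values)
print(grid_features.max(dim=0).values)
```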