Commit 41364a8

Merge branch 'main' of https://github.com/mllam/neural-lam into maint/deps-in-pyproject-toml

leifdenby committed Aug 19, 2024
2 parents 9f3c014 + a54c45f
Showing 22 changed files with 134 additions and 89 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -35,3 +35,4 @@ repos:
hooks:
- id: flake8
description: Check Python code for correctness, consistency and adherence to best practices
+ additional_dependencies: [Flake8-pyproject]
7 changes: 7 additions & 0 deletions CHANGELOG.md
@@ -94,6 +94,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
[\#68](https://github.com/mllam/neural-lam/pull/68)
@joeloskarsson

+ - turn `neural-lam` into a python package by moving all `*.py`-files into the
+   `neural_lam/` source directory and updating imports accordingly. This means
+   all CLI functions are now invoked through the package name, e.g. `python -m
+   neural_lam.train_model` instead of `python train_model.py` (and can be done
+   anywhere once the package has been installed).
+   [\#32](https://github.com/mllam/neural-lam/pull/32), @leifdenby
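
For illustration, `python -m` runs the module as a script from any directory once the package is installed; a minimal sketch of the equivalent programmatic call (assuming `train_model` exposes a `main()` entry point like the other scripts in this diff):

```python
# CLI equivalent: python -m neural_lam.train_model
from neural_lam import train_model

train_model.main()  # parses CLI arguments internally (assumption)
```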

## [v0.1.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.1.0)

First tagged release of `neural-lam`, matching Oskarsson et al 2023 publication
84 changes: 42 additions & 42 deletions README.md
@@ -49,7 +49,7 @@ Still, some restrictions are inevitable:
## A note on the limited area setting
Currently we are using these models on a limited area covering the Nordic region, the so called MEPS area (see [paper](https://arxiv.org/abs/2309.17370)).
There are still some parts of the code that are quite specific to the MEPS area use case.
- This is in particular true for the mesh graph creation (`create_mesh.py`) and some of the constants set in a `data_config.yaml` file (path specified in `train_model.py --data_config` ).
+ This is in particular true for the mesh graph creation (`python -m neural_lam.create_mesh`) and some of the constants set in a `data_config.yaml` file (path specified in `python -m neural_lam.train_model --data_config <data-config-filepath>`).
If there is interest in using Neural-LAM for other areas, it is not a substantial undertaking to refactor the code to be fully area-agnostic.
We would be happy to support such enhancements.
See the issues https://github.com/joeloskarsson/neural-lam/issues/2, https://github.com/joeloskarsson/neural-lam/issues/3 and https://github.com/joeloskarsson/neural-lam/issues/4 for some initial ideas on how this could be done.
@@ -96,39 +96,39 @@ See the [repository format section](#format-of-data-directory) for details on th
The full MEPS dataset can be shared with other researchers on request; contact us for this.
A tiny subset of the data (named `meps_example`) is available in `example_data.zip`, which can be downloaded from [here](https://liuonline-my.sharepoint.com/:f:/g/personal/joeos82_liu_se/EuiUuiGzFIFHruPWpfxfUmYBSjhqMUjNExlJi9W6ULMZ1w?e=97pnGX).
Download the file and unzip in the neural-lam directory.
- All graphs used in the paper are also available for download at the same link (but can as easily be re-generated using `create_mesh.py`).
- Note that this is far too little data to train any useful models, but all scripts can be ran with it.
+ All graphs used in the paper are also available for download at the same link (but can as easily be re-generated using `python -m neural_lam.create_mesh`).
+ Note that this is far too little data to train any useful models, but all pre-processing and training steps can be run with it.
It should thus be useful to make sure that your Python environment is set up correctly and that all the code can be run without any issues.

## Pre-processing
- An overview of how the different scripts and files depend on each other is given in this figure:
+ An overview of how the different pre-processing steps, training and files depend on each other is given in this figure:
<p align="middle">
<img src="figures/component_dependencies.png"/>
</p>
- In order to start training models at least three pre-processing scripts have to be ran:
+ In order to start training models at least three pre-processing steps have to be run:

- * `create_mesh.py`
- * `create_grid_features.py`
- * `create_parameter_weights.py`
+ * `python -m neural_lam.create_mesh`
+ * `python -m neural_lam.create_grid_features`
+ * `python -m neural_lam.create_parameter_weights`

### Create graph
- Run `create_mesh.py` with suitable options to generate the graph you want to use (see `python create_mesh.py --help` for a list of options).
+ Run `python -m neural_lam.create_mesh` with suitable options to generate the graph you want to use (see `python -m neural_lam.create_mesh --help` for a list of options).
The graphs used for the different models in the [paper](https://arxiv.org/abs/2309.17370) can be created as:

- * **GC-LAM**: `python create_mesh.py --graph multiscale`
- * **Hi-LAM**: `python create_mesh.py --graph hierarchical --hierarchical 1` (also works for Hi-LAM-Parallel)
- * **L1-LAM**: `python create_mesh.py --graph 1level --levels 1`
+ * **GC-LAM**: `python -m neural_lam.create_mesh --graph multiscale`
+ * **Hi-LAM**: `python -m neural_lam.create_mesh --graph hierarchical --hierarchical 1` (also works for Hi-LAM-Parallel)
+ * **L1-LAM**: `python -m neural_lam.create_mesh --graph 1level --levels 1`

The graph-related files are stored in a directory called `graphs`.

### Create remaining static features
- To create the remaining static files run the scripts `create_grid_features.py` and `create_parameter_weights.py`.
+ To create the remaining static files run `python -m neural_lam.create_grid_features` and `python -m neural_lam.create_parameter_weights`.

## Weights & Biases Integration
The project is fully integrated with [Weights & Biases](https://www.wandb.ai/) (W&B) for logging and visualization, but can just as easily be used without it.
When W&B is used, training configuration, training/test statistics and plots are sent to the W&B servers and made available in an interactive web interface.
If W&B is turned off, logging instead saves everything locally to a directory like `wandb/dryrun...`.
- The W&B project name is set to `neural-lam`, but this can be changed in the flags of `train_model.py` (using argsparse).
+ The W&B project name is set to `neural-lam`, but this can be changed in the flags of `python -m neural_lam.train_model` (using argparse).
See the [W&B documentation](https://docs.wandb.ai/) for details.

If you would like to login and use W&B, run:
@@ -141,8 +141,8 @@ wandb off
```

## Train Models
- Models can be trained using `train_model.py`.
- Run `python train_model.py --help` for a full list of training options.
+ Models can be trained using `python -m neural_lam.train_model`.
+ Run `python -m neural_lam.train_model --help` for a full list of training options.
A few of the key ones are outlined below:

* `--dataset`: Which data to train on
@@ -161,20 +161,20 @@ This model class is used both for the L1-LAM and GC-LAM models from the [paper](

To train L1-LAM use
```
- python train_model.py --model graph_lam --graph 1level ...
+ python -m neural_lam.train_model --model graph_lam --graph 1level ...
```

To train GC-LAM use
```
- python train_model.py --model graph_lam --graph multiscale ...
+ python -m neural_lam.train_model --model graph_lam --graph multiscale ...
```

### Hi-LAM
A version of Graph-LAM that uses a hierarchical mesh graph and performs sequential message passing through the hierarchy during processing.

To train Hi-LAM use
```
- python train_model.py --model hi_lam --graph hierarchical ...
+ python -m neural_lam.train_model --model hi_lam --graph hierarchical ...
```

### Hi-LAM-Parallel
Expand All @@ -183,13 +183,13 @@ Not included in the paper as initial experiments showed worse results than Hi-LA

To train Hi-LAM-Parallel use
```
- python train_model.py --model hi_lam_parallel --graph hierarchical ...
+ python -m neural_lam.train_model --model hi_lam_parallel --graph hierarchical ...
```

Checkpoint files for our models trained on the MEPS data are available upon request.

## Evaluate Models
- Evaluation is also done using `train_model.py`, but using the `--eval` option.
+ Evaluation is also done using `python -m neural_lam.train_model`, but using the `--eval` option.
Use `--eval val` to evaluate the model on the validation set and `--eval test` to evaluate on test data.
Most of the training options are also relevant for evaluation (except `ar_steps`; evaluation always unrolls full forecasts).
Some options specifically important for evaluation are:
@@ -232,13 +232,13 @@ data
│ ├── nwp_xy.npy - Coordinates of grid nodes (part of dataset)
│ ├── surface_geopotential.npy - Geopotential at surface of grid nodes (part of dataset)
│ ├── border_mask.npy - Mask with True for grid nodes that are part of border (part of dataset)
- │ ├── grid_features.pt - Static features of grid nodes (create_grid_features.py)
- │ ├── parameter_mean.pt - Means of state parameters (create_parameter_weights.py)
- │ ├── parameter_std.pt - Std.-dev. of state parameters (create_parameter_weights.py)
- │ ├── diff_mean.pt - Means of one-step differences (create_parameter_weights.py)
- │ ├── diff_std.pt - Std.-dev. of one-step differences (create_parameter_weights.py)
- │ ├── flux_stats.pt - Mean and std.-dev. of solar flux forcing (create_parameter_weights.py)
- │ └── parameter_weights.npy - Loss weights for different state parameters (create_parameter_weights.py)
+ │ ├── grid_features.pt - Static features of grid nodes (neural_lam.create_grid_features)
+ │ ├── parameter_mean.pt - Means of state parameters (neural_lam.create_parameter_weights)
+ │ ├── parameter_std.pt - Std.-dev. of state parameters (neural_lam.create_parameter_weights)
+ │ ├── diff_mean.pt - Means of one-step differences (neural_lam.create_parameter_weights)
+ │ ├── diff_std.pt - Std.-dev. of one-step differences (neural_lam.create_parameter_weights)
+ │ ├── flux_stats.pt - Mean and std.-dev. of solar flux forcing (neural_lam.create_parameter_weights)
+ │ └── parameter_weights.npy - Loss weights for different state parameters (neural_lam.create_parameter_weights)
├── dataset2
├── ...
└── datasetN
@@ -250,13 +250,13 @@ The structure is shown with examples below:
```
graphs
├── graph1 - Directory with a graph definition
- │ ├── m2m_edge_index.pt - Edges in mesh graph (create_mesh.py)
- │ ├── g2m_edge_index.pt - Edges from grid to mesh (create_mesh.py)
- │ ├── m2g_edge_index.pt - Edges from mesh to grid (create_mesh.py)
- │ ├── m2m_features.pt - Static features of mesh edges (create_mesh.py)
- │ ├── g2m_features.pt - Static features of grid to mesh edges (create_mesh.py)
- │ ├── m2g_features.pt - Static features of mesh to grid edges (create_mesh.py)
- │ └── mesh_features.pt - Static features of mesh nodes (create_mesh.py)
+ │ ├── m2m_edge_index.pt - Edges in mesh graph (neural_lam.create_mesh)
+ │ ├── g2m_edge_index.pt - Edges from grid to mesh (neural_lam.create_mesh)
+ │ ├── m2g_edge_index.pt - Edges from mesh to grid (neural_lam.create_mesh)
+ │ ├── m2m_features.pt - Static features of mesh edges (neural_lam.create_mesh)
+ │ ├── g2m_features.pt - Static features of grid to mesh edges (neural_lam.create_mesh)
+ │ ├── m2g_features.pt - Static features of mesh to grid edges (neural_lam.create_mesh)
+ │ └── mesh_features.pt - Static features of mesh nodes (neural_lam.create_mesh)
├── graph2
├── ...
└── graphN
@@ -266,9 +266,9 @@ graphs
To keep track of levels in the mesh graph, a list format is used for the files with mesh graph information.
In particular, the files
```
- │ ├── m2m_edge_index.pt - Edges in mesh graph (create_mesh.py)
- │ ├── m2m_features.pt - Static features of mesh edges (create_mesh.py)
- │ ├── mesh_features.pt - Static features of mesh nodes (create_mesh.py)
+ │ ├── m2m_edge_index.pt - Edges in mesh graph (neural_lam.create_mesh)
+ │ ├── m2m_features.pt - Static features of mesh edges (neural_lam.create_mesh)
+ │ ├── mesh_features.pt - Static features of mesh nodes (neural_lam.create_mesh)
```
all contain lists of length `L`, for a hierarchical mesh graph with `L` layers.
For non-hierarchical graphs `L == 1` and these are all just single-entry lists.
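
As an illustration of this list format, a minimal sketch of inspecting a saved graph (assuming the `graphs/graph1` layout above; `torch.load` returns the stored Python list):

```python
import torch

# Each mesh-level file stores a list of length L (the number of mesh levels)
m2m_edge_index = torch.load("graphs/graph1/m2m_edge_index.pt")
num_levels = len(m2m_edge_index)  # L == 1 for non-hierarchical graphs
print(f"mesh graph has {num_levels} level(s)")
```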
@@ -279,10 +279,10 @@ In addition, hierarchical mesh graphs (`L > 1`) feature a few additional files w
```
├── graph1
│ ├── ...
- │ ├── mesh_down_edge_index.pt - Downward edges in mesh graph (create_mesh.py)
- │ ├── mesh_up_edge_index.pt - Upward edges in mesh graph (create_mesh.py)
- │ ├── mesh_down_features.pt - Static features of downward mesh edges (create_mesh.py)
- │ ├── mesh_up_features.pt - Static features of upward mesh edges (create_mesh.py)
+ │ ├── mesh_down_edge_index.pt - Downward edges in mesh graph (neural_lam.create_mesh)
+ │ ├── mesh_up_edge_index.pt - Upward edges in mesh graph (neural_lam.create_mesh)
+ │ ├── mesh_down_features.pt - Static features of downward mesh edges (neural_lam.create_mesh)
+ │ ├── mesh_up_features.pt - Static features of upward mesh edges (neural_lam.create_mesh)
│ ├── ...
```
These files have the same list format as the ones above, but each list has length `L-1` (as these edges describe connections between levels).
10 changes: 10 additions & 0 deletions neural_lam/__init__.py
@@ -0,0 +1,10 @@
+ # First-party
+ import neural_lam.config
+ import neural_lam.interaction_net
+ import neural_lam.metrics
+ import neural_lam.models
+ import neural_lam.utils
+ import neural_lam.vis
+
+ # Local
+ from .weather_dataset import WeatherDataset
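
With these package-level imports in place, a minimal usage sketch (assumes the package is installed and the `meps_example` dataset from the README is present under `data/`; further constructor arguments are omitted):

```python
# Works from any directory once `neural-lam` is installed
from neural_lam import WeatherDataset

ds = WeatherDataset("meps_example")  # dataset name is illustrative
print(len(ds))
```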
create_grid_features.py → neural_lam/create_grid_features.py
@@ -6,8 +6,8 @@
import numpy as np
import torch

- # First-party
- from neural_lam import config
+ # Local
+ from . import config


def main():
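
The same absolute-to-relative import swap repeats across the files below; a minimal sketch of the equivalence (hypothetical module inside the package):

```python
# neural_lam/some_module.py (hypothetical)
# Both lines bind the same `config` submodule once neural_lam is a package;
# the relative form makes the intra-package dependency explicit and keeps
# working regardless of how or where the package is installed.
from neural_lam import config  # absolute (old style)
from . import config           # relative (new style)
```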
4 changes: 2 additions & 2 deletions create_mesh.py → neural_lam/create_mesh.py
@@ -12,8 +12,8 @@
import torch_geometric as pyg
from torch_geometric.utils.convert import from_networkx

- # First-party
- from neural_lam import config
+ # Local
+ from . import config


def plot_graph(graph, title=None):
create_parameter_weights.py → neural_lam/create_parameter_weights.py
@@ -10,9 +10,8 @@
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm

- # First-party
- from neural_lam import config
- from neural_lam.weather_dataset import WeatherDataset
+ # Local
+ from . import WeatherDataset, config


class PaddedWeatherDataset(torch.utils.data.Dataset):
4 changes: 2 additions & 2 deletions neural_lam/interaction_net.py
@@ -3,8 +3,8 @@
import torch_geometric as pyg
from torch import nn

- # First-party
- from neural_lam import utils
+ # Local
+ from . import utils


class InteractionNet(pyg.nn.MessagePassing):
6 changes: 6 additions & 0 deletions neural_lam/models/__init__.py
@@ -0,0 +1,6 @@
+ # Local
+ from .base_graph_model import BaseGraphModel
+ from .base_hi_graph_model import BaseHiGraphModel
+ from .graph_lam import GraphLAM
+ from .hi_lam import HiLAM
+ from .hi_lam_parallel import HiLAMParallel
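
With this subpackage `__init__.py`, the model classes import in one line; a minimal sketch:

```python
# Previously: from neural_lam.models.graph_lam import GraphLAM
from neural_lam.models import GraphLAM, HiLAM, HiLAMParallel
```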
4 changes: 2 additions & 2 deletions neural_lam/models/ar_model.py
@@ -8,8 +8,8 @@
import torch
import wandb

- # First-party
- from neural_lam import config, metrics, utils, vis
+ # Local
+ from .. import config, metrics, utils, vis


class ARModel(pl.LightningModule):
8 changes: 4 additions & 4 deletions neural_lam/models/base_graph_model.py
@@ -1,10 +1,10 @@
# Third-party
import torch

- # First-party
- from neural_lam import utils
- from neural_lam.interaction_net import InteractionNet
- from neural_lam.models.ar_model import ARModel
+ # Local
+ from .. import utils
+ from ..interaction_net import InteractionNet
+ from .ar_model import ARModel


class BaseGraphModel(ARModel):
8 changes: 4 additions & 4 deletions neural_lam/models/base_hi_graph_model.py
@@ -1,10 +1,10 @@
# Third-party
from torch import nn

- # First-party
- from neural_lam import utils
- from neural_lam.interaction_net import InteractionNet
- from neural_lam.models.base_graph_model import BaseGraphModel
+ # Local
+ from .. import utils
+ from ..interaction_net import InteractionNet
+ from .base_graph_model import BaseGraphModel


class BaseHiGraphModel(BaseGraphModel):
8 changes: 4 additions & 4 deletions neural_lam/models/graph_lam.py
@@ -1,10 +1,10 @@
# Third-party
import torch_geometric as pyg

- # First-party
- from neural_lam import utils
- from neural_lam.interaction_net import InteractionNet
- from neural_lam.models.base_graph_model import BaseGraphModel
+ # Local
+ from .. import utils
+ from ..interaction_net import InteractionNet
+ from .base_graph_model import BaseGraphModel


class GraphLAM(BaseGraphModel):
6 changes: 3 additions & 3 deletions neural_lam/models/hi_lam.py
@@ -1,9 +1,9 @@
# Third-party
from torch import nn

- # First-party
- from neural_lam.interaction_net import InteractionNet
- from neural_lam.models.base_hi_graph_model import BaseHiGraphModel
+ # Local
+ from ..interaction_net import InteractionNet
+ from .base_hi_graph_model import BaseHiGraphModel


class HiLAM(BaseHiGraphModel):
6 changes: 3 additions & 3 deletions neural_lam/models/hi_lam_parallel.py
@@ -2,9 +2,9 @@
import torch
import torch_geometric as pyg

- # First-party
- from neural_lam.interaction_net import InteractionNet
- from neural_lam.models.base_hi_graph_model import BaseHiGraphModel
+ # Local
+ from ..interaction_net import InteractionNet
+ from .base_hi_graph_model import BaseHiGraphModel


class HiLAMParallel(BaseHiGraphModel):
9 changes: 3 additions & 6 deletions train_model.py → neural_lam/train_model.py
@@ -9,12 +9,9 @@
import torch
from lightning_fabric.utilities import seed

- # First-party
- from neural_lam import config, utils
- from neural_lam.models.graph_lam import GraphLAM
- from neural_lam.models.hi_lam import HiLAM
- from neural_lam.models.hi_lam_parallel import HiLAMParallel
- from neural_lam.weather_dataset import WeatherDataset
+ # Local
+ from . import WeatherDataset, config, utils
+ from .models import GraphLAM, HiLAM, HiLAMParallel

MODELS = {
"graph_lam": GraphLAM,
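
The hunk above cuts off after the first entry of `MODELS`; a minimal sketch of how such a registry maps the `--model` CLI flag to a model class (the remaining entries are an assumption, mirroring the model names used in the README):

```python
MODELS = {
    "graph_lam": GraphLAM,
    "hi_lam": HiLAM,
    "hi_lam_parallel": HiLAMParallel,
}

# e.g. `--model hi_lam` selects the HiLAM class
model_class = MODELS[args.model]
```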
4 changes: 2 additions & 2 deletions neural_lam/vis.py
@@ -3,8 +3,8 @@
import matplotlib.pyplot as plt
import numpy as np

- # First-party
- from neural_lam import utils
+ # Local
+ from . import utils


@matplotlib.rc_context(utils.fractional_plot_bundle(1))
4 changes: 2 additions & 2 deletions neural_lam/weather_dataset.py
@@ -7,8 +7,8 @@
import numpy as np
import torch

- # First-party
- from neural_lam import utils
+ # Local
+ from . import utils


class WeatherDataset(torch.utils.data.Dataset):
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -34,6 +34,9 @@ dev = [
"pooch>=1.8.1",
]

+ [tool.setuptools]
+ py-modules = ["neural_lam"]

[tool.black]
line-length = 80

Empty file removed tests/__init__.py