Skip to content

Commit

Permalink
Merge branch 'main' into comparison_gifs
Browse files Browse the repository at this point in the history
  • Loading branch information
LBerth committed Sep 18, 2024
2 parents 2528e0d + 20e3bb8 commit 61f0625
Show file tree
Hide file tree
Showing 12 changed files with 343 additions and 92 deletions.
39 changes: 0 additions & 39 deletions .gitlab-ci.yml

This file was deleted.

11 changes: 8 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -304,23 +304,28 @@ You can find more details about all the `num_X_steps` options [here](doc/num_ste

### Inference

Inference is done by running the `bin/inference.py` script. This script will load a model and run it on a dataset using the training parameters (dataset name, dataset config, timestep options, ...).
Inference is done by running the `bin/inference.py` script. This script will load a model and run it on a dataset using the training parameters (dataset config, timestep options, ...).

```bash
usage: py4cast Inference script [-h] [--model_path MODEL_PATH] [--dataset DATASET] [--ds_config_file DS_CONFIG_FILE] [--date DATE]
usage: py4cast Inference script [-h] [--model_path MODEL_PATH] [--dataset DATASET] [--infer_steps INFER_STEPS] [--date DATE]

options:
-h, --help show this help message and exit
--model_path MODEL_PATH
Path to the model checkpoint
--date DATE
Date of the sample to infer on. Format:YYYYMMDDHH
--dataset DATASET
Name of the dataset config file to use
--infer_steps INFER_STEPS
Number of auto-regressive steps/prediction steps during the inference

```

A simple example of inference is shown below:

```bash
runai exec_gpu python bin/inference.py --model_path /scratch/shared/py4cast/logs/comparison/titan/swinunetr/bert_swinunetr_0/last.ckpt --date 2023123123
runai exec_gpu python bin/inference.py --model_path /scratch/shared/py4cast/logs/camp0/poesy/halfunet/sezn_run_dev_9/last.ckpt --date 2021061621 --dataset poesy_infer --infer_steps 2

```

Expand Down
25 changes: 17 additions & 8 deletions bin/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,30 @@
parser = argparse.ArgumentParser("py4cast Inference script")
parser.add_argument("--model_path", type=str, help="Path to the model checkpoint")
parser.add_argument("--date", type=str, help="Date for inference", default=None)
parser.add_argument(
"--dataset", type=str, help="Dataset used in inference", default="poesy_infer"
)
parser.add_argument(
"--infer_steps", type=int, help="Number of inference steps", default=1
)

args = parser.parse_args()

# Load checkpoint
lightning_module = AutoRegressiveLightning.load_from_checkpoint(args.model_path)
hparams = lightning_module.hparams["hparams"]

if args.date is not None:
config_override = {"periods": {"test": {"start": args.date, "end": args.date}}}
config_override = {
"periods": {"test": {"start": args.date, "end": args.date}},
"num_inference_pred_steps": args.infer_steps,
}
else:
config_override = None

hparams = lightning_module.hparams["hparams"]
config_override = {"num_inference_pred_steps": args.infer_steps}

# Get dataset for inference
_, _, test_ds = get_datasets(
hparams.dataset_name,
_, _, infer_ds = get_datasets(
args.dataset,
hparams.num_input_steps,
hparams.num_pred_steps_train,
hparams.num_pred_steps_val_test,
Expand All @@ -36,7 +44,8 @@
)

# Transform in dataloader
dl_settings = TorchDataloaderSettings(batch_size=2)
infer_loader = test_ds.torch_dataloader(dl_settings)

dl_settings = TorchDataloaderSettings(batch_size=1)
infer_loader = infer_ds.torch_dataloader(dl_settings)
trainer = Trainer(devices="auto")
preds = trainer.predict(lightning_module, infer_loader)
65 changes: 65 additions & 0 deletions config/datasets/poesy_infer.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
{
"periods": {
"train": {
"start": 2020061521,
"end": 2021061521,
"step": 24
},
"test": {
"start": 2021061621,
"end": 2021071621,
"step": 24
},
"valid": {
"start": 2021071721,
"end": 2021111521,
"step": 24
}
},
"grid":{
"geometry":"EURW1S40",
"border_size":10,
"domain":"france",
"model":"arome",
"subgrid":[50,178,50,178]
},
"settings":{
"step_duration": 1,
"standardize": true,
"file_format": "npy"
},
"dataset": {
"arome": {
"members": [
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16
],
"term": {
"start": 0,
"end": 44,
"timestep": 1
},
"var": {
"TEMPERATURE_2m": {"shortname":"t2m", "level": [1]},
"PRECIPITATIONS_DECUM":{"shortname":"rrdecum","level": [1]},
"WIND.U.PHYS":{"shortname":"u","level": [1]},
"WIND.V.PHYS":{"shortname":"v","level": [1]}
}
}
}
}
7 changes: 7 additions & 0 deletions config/models/unetrpp8512_linear_up_d2.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
{
"num_heads": 8,
"hidden_size": 512,
"linear_upsampling": true,
"downsampling_rate": 2
}

2 changes: 1 addition & 1 deletion doc/features.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ Currently we support the following neural network architectures:
| [segformer](../py4cast/models/vision/transformers.py#L216) | [arxiv link](https://arxiv.org/abs/2105.15203) | (Batch, Height, Width, features) | On par with u-net like on Deepsyg (MF internal), added an upsampling stage. Adapted from [Lucidrains' github](https://github.com/lucidrains/segformer-pytorch) | Frank Guibert |
| [swinunetr](../py4cast/models/vision/transformers.py#L335) | [arxiv link](https://arxiv.org/abs/2201.01266) | (Batch, Height, Width, features) | 2D Swin Unet transformer (Pangu and archweather uses customised 3D versions of Swin Transformers). Plugged in from [MONAI](https://github.com/Project-MONAI/MONAI/). The decoders have been modified to use Bilinear2D + Conv2d instead of Conv2dTranspose to remove artefacts/checkerboard effects | Frank Guibert |
| [hilam](../py4cast/models/nlam/models.py#L754), graphlam | [arxiv link](https://arxiv.org/abs/2309.17370) | (Batch, graph_node_id, features) | Imported and adapted from [Joel's github](https://github.com/joeloskarsson/neural-lam) | Vincent Chabot/Frank Guibert |
| [unetrpp](../py4cast/models/vision/unetrpp.py#L1) | [arxiv link](https://arxiv.org/abs/2212.04497) | (Batch, features, Height, Width) or (Batch, features, Height, Width, Depth) | Vision transformer with a reduced GFLOPS footprint adapted from [author's github](https://github.com/Amshaker/unetr_plus_plus). Modified to work both with 2d and 3d inputs | Frank Guibert |
| [unetrpp](../py4cast/models/vision/unetrpp.py#L1) | [arxiv link](https://arxiv.org/abs/2212.04497) | (Batch, features, Height, Width) or (Batch, features, Height, Width, Depth) | Vision transformer with a reduced GFLOPS footprint adapted from [author's github](https://github.com/Amshaker/unetr_plus_plus). Modified to work both with 2d and 3d inputs. Changed Upsampling to use linear upsampling. Made stem layer downsampling rate a parameter. | Frank Guibert |

## Available datasets

Expand Down
10 changes: 10 additions & 0 deletions py4cast/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,16 @@
except ImportError:
warnings.warn(f"Could not import PoesyDataset. {traceback.format_exc()}")

try:
from .poesy import InferPoesyDataset

registry["poesy_infer"] = (
InferPoesyDataset,
default_config_root / "poesy_infer.json",
)
except ImportError:
warnings.warn(f"Could not import InferPoesyDataset. {traceback.format_exc()}")

try:
from .dummy import DummyDataset

Expand Down
2 changes: 1 addition & 1 deletion py4cast/datasets/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ class Item:

inputs: NamedTensor
forcing: NamedTensor
outputs: Union[None, NamedTensor] = None
outputs: NamedTensor

def __post_init__(self):
"""
Expand Down
Loading

0 comments on commit 61f0625

Please sign in to comment.