docs + clean
LBerth committed Sep 19, 2024
1 parent 61f0625 commit e9d5d0d
Showing 3 changed files with 69 additions and 63 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -10,6 +10,9 @@ Contributions are welcome (Issues, Pull Requests, ...).

This project is licensed under the [APACHE 2.0 license.](LICENSE-2.0.txt)

![Forecast humidity](figs/2023061812_aro_r2_2m_crop.gif)
![Forecast precip](figs/2023061812_aro_tp_0m.gif)

# Acknowledgements

This project started as a fork of neural-lam, a project by Joel Oskarsson, see [here](https://github.com/mllam/neural-lam). Many thanks to Joel for his work!
128 changes: 65 additions & 63 deletions bin/gif_comparison.py
@@ -1,3 +1,18 @@
"""Plots animations comparing forecasts of multiple models with ground truth.
/!\ - For now, this script only works with models trained on the Titan dataset.
- If you want to use AROME as a model, you have to download the forecast manually beforehand.
usage: gif_comparison.py [-h] --ckpt CKPT --date DATE [--num_pred_steps NUM_PRED_STEPS]
options:
-h, --help show this help message and exit
--ckpt CKPT Paths to the model checkpoint or AROME
--date DATE Date for inference. Format YYYYMMDDHH.
--num_pred_steps NUM_PRED_STEPS
Number of auto-regressive steps/prediction steps.
example: python bin/gif_comparison.py --ckpt AROME --ckpt /.../logs/my_run/epoch=247.ckpt --date 2023061812 --num_pred_steps 10"""

import argparse


@@ -18,48 +33,27 @@
from skimage.transform import resize
import xarray as xr

from py4cast.datasets.titan.settings import METADATA

BASE_PATH = Path("/scratch/shared/Titan/AROME/")
from py4cast.datasets.titan.settings import METADATA, AROME_PATH

COLORMAPS = {
"t2m": {
"cmap": "Spectral_r",
"vmin": 240,
"vmax": 320
},
"r2": {
"cmap": "Spectral",
"vmin": 0,
"vmax": 100
},
"tp": {
"cmap": "Spectral_r",
"vmin": 0.5,
"vmax": 100
},
"u10": {
"cmap": "RdBu",
"vmin": -20,
"vmax": 20
},
"v10": {
"cmap": "RdBu",
"vmin": -20,
"vmax": 20
}
"t2m": {"cmap": "Spectral_r", "vmin": 240, "vmax": 320},
"r2": {"cmap": "Spectral", "vmin": 0, "vmax": 100},
"tp": {"cmap": "Spectral_r", "vmin": 0.5, "vmax": 100},
"u10": {"cmap": "RdBu","vmin": -20, "vmax": 20},
"v10": {"cmap": "RdBu", "vmin": -20, "vmax": 20}
}
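
For reference, a minimal sketch of how one of these entries can drive a matplotlib plot; `plot_frame` below presumably does a similar lookup by short parameter name, and the sample field here is made up:

```python
import matplotlib.pyplot as plt
import numpy as np

param = "t2m"
cfg = COLORMAPS[param]
# Made-up 2 m temperature field (K) on the cropped 512 x 640 subdomain
data = 240 + 80 * np.random.rand(512, 640)

plt.imshow(data, cmap=cfg["cmap"], vmin=cfg["vmin"], vmax=cfg["vmax"])
plt.colorbar(label="t2m (K)")
plt.savefig("t2m_colormap_example.png")
```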


def downscale(array):
subdomain=[100, 612, 240, 880]
grid_info = METADATA["GRIDS"]["PAAROME_1S40"]
def downscale(array: np.ndarray, grid: str = "PAAROME_1S40", domain: Tuple[int, int, int, int] = (100, 612, 240, 880)) -> np.ndarray:
"""Downscales an array from Titan grid 1S100 to another grid and subdomain."""
grid_info = METADATA["GRIDS"][grid]
array = resize(array, grid_info["size"], anti_aliasing=True)
array = array[subdomain[0] : subdomain[1], subdomain[2] : subdomain[3]]
array = array[domain[0] : domain[1], domain[2] : domain[3]]
return array

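A quick usage sketch for `downscale` (the 1S100 input shape is an assumption for illustration; with the default subdomain the crop yields a 512 x 640 array, provided the target grid is at least that large):

```python
import numpy as np

# Hypothetical full-resolution field on the Titan 1S100 grid (shape assumed)
field_1s100 = np.random.rand(1791, 2801)

# Resize to the PAAROME_1S40 grid, then crop rows 100:612 and columns 240:880
field_crop = downscale(field_1s100)
print(field_crop.shape)  # expected: (512, 640)
```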

def get_param(path:Path, param:str) -> np.ndarray:
"""Extracts a weather param from an AROME forecast in grib."""
ds = xr.open_dataset(path, engine="cfgrib")
array = ds[param].values
arr_list = [downscale(array[t]) for t in range(array.shape[0])]
@@ -68,13 +62,15 @@ def get_param(path:Path, param:str) -> np.ndarray:


def post_process_tp_arome(array:np.ndarray) -> np.ndarray:
"""Converts AROME precip forecast in mm/h.
By default, AROME accumulates mm starting from t0."""
diff_arrs = [array[t+1] - array[t] for t in range(12)]
return np.stack(diff_arrs)

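A tiny worked example of the de-accumulation above, with made-up values (the real input is a (time, lat, lon) array, but the logic is the same along the time axis):

```python
import numpy as np

# Made-up cumulative precipitation (mm since t0) for 13 lead times
acc = np.linspace(0.0, 6.0, 13)

hourly = post_process_tp_arome(acc)
print(hourly.shape)  # (12,)
print(hourly)        # 0.5 mm/h at every step in this toy case
```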

def read_arome(date:str) -> Tuple[np.ndarray]:
path = BASE_PATH / date
print(path)
"""Extracts 5 parameters (t2m, r2, tp, u10, v10) of an AROME forecast."""
path = AROME_PATH / date
r2 = get_param(path / "AROME_1S100_ECH0_2M.grib", "r2")
t2m = get_param(path / "AROME_1S100_ECH0_2M.grib", "t2m")
u10 = get_param(path / "AROME_1S100_ECH0_10M.grib", "u10")
@@ -85,6 +81,7 @@ def read_arome(date:str) -> Tuple[np.ndarray]:


def get_model_and_hparams(ckpt: Path, num_pred_steps:int) -> Tuple[AutoRegressiveLightning, ArLightningHyperParam]:
"""Loads a model from its checkpoint and changes the nb of forecast steps."""
model = AutoRegressiveLightning.load_from_checkpoint(ckpt)
hparams = model.hparams["hparams"]
hparams.num_pred_steps_val_test = num_pred_steps
@@ -93,7 +90,7 @@ def get_model_and_hparams(ckpt: Path, num_pred_steps:int) -> Tuple[AutoRegressiv


def get_item_for_date(date:str, hparams: ArLightningHyperParam) -> Item:
""" Returns Item containing sample of chosen date.
"""Returns an Item containing one sample for a chosen date.
Date should be in format YYYYMMDDHH.
"""
config_override = {"periods": {"test": {"start": date, "end": date}}}
@@ -110,6 +107,7 @@ def get_item_for_date(date:str, hparams: ArLightningHyperParam) -> Item:


def make_forecast(model: AutoRegressiveLightning, item: Item) -> torch.tensor:
"""Applies a model an Item to make a forecast."""
batch_item = collate_fn([item])
preds = model(batch_item)
forecast = preds.tensor
@@ -122,6 +120,7 @@ def make_forecast(model: AutoRegressiveLightning, item: Item) -> torch.tensor:


def post_process_outputs(y: torch.tensor, feature_names:List[str], feature_names_to_idx:dict) -> np.ndarray:
"""Post-processes one forecast by de-normalizing the values of each feature."""
arrays = []
for feature_name in feature_names:
idx_feature = feature_names_to_idx[feature_name]
@@ -134,7 +133,8 @@ def post_process_outputs(y: torch.tensor, feature_names:List[str], feature_names

@gif.frame
def plot_frame(feature_name: str, target: np.ndarray, predictions: List[np.ndarray], domain_info: DomainInfo,
title=None, models_names=None, unit:str=None, vmin:float=None, vmax:float=None)-> None:
title:str=None, models_names:List[str]=None, colorbar_label:str=None, vmin:float=None, vmax:float=None)-> None:
"""Plots one frame of the animation."""

nb_preds = len(predictions) + 1
lines = int(math.sqrt(nb_preds))
@@ -157,7 +157,6 @@ def plot_frame(feature_name: str, target: np.ndarray, predictions: List[np.ndarr

for i, data in enumerate([target] + predictions):
axes[i].coastlines()
# array = data
if param == "tp": # precipitations
data = np.where(data < 0.5, np.nan, data)
im = axes[i].imshow(
@@ -166,15 +165,36 @@ def plot_frame(feature_name: str, target: np.ndarray, predictions: List[np.ndarr
if models_names:
axes[i].set_title(models_names[i], size=15)

subfig.colorbar(im, ax=axes, location='bottom', label=unit, aspect=40)
subfig.colorbar(im, ax=axes, location='bottom', label=colorbar_label, aspect=40)

if title:
fig.suptitle(title, size=20)

copyright = "Météo-France, Py4cast project."
fig.text(0, 0.02, copyright, fontsize=8, ha="left")


def make_gif(feature: str, date: str, target: np.ndarray, preds: List[np.ndarray], models_names: List[str], domain_info: DomainInfo) -> None:
"""Plots a GIF comparing multiple forecasts of one feature."""
vmin, vmax = target.min(), target.max()
date_dt = dt.datetime.strptime(date, "%Y%m%d%H")
date_str = date_dt.strftime("%Y-%m-%d %Hh UTC")
short_name = "_".join(feature.split("_")[:2])
feature_str = METADATA["WEATHER_PARAMS"][short_name]["long_name"][6:]
# NB: relies on the module-level `hparams` set in the __main__ block below
unit = f"{feature_str} ({hparams.dataset_info.units[feature]})"

frames = []
for t in trange(target.shape[0]):
title = f"{date_str} +{t+1}h"
preds_t = [pred[t] for pred in preds]
frame = plot_frame(feature, target[t], preds_t, domain_info, title, models_names, unit, vmin, vmax)
frames.append(frame)
gif.save(frames, f"{date}_{feature}.gif", duration=500)


if __name__ == "__main__":

parser = argparse.ArgumentParser("py4cast Inference script")
parser = argparse.ArgumentParser("Plot animations")
parser.add_argument("--ckpt", type=str, action='append', help="Paths to the model checkpoint or AROME", required=True)
parser.add_argument("--date", type=str, help="Date for inference. Format YYYYMMDDHH.", required=True)
parser.add_argument(
@@ -193,17 +213,18 @@ def plot_frame(feature_name: str, target: np.ndarray, predictions: List[np.ndarr
if ckpt == "AROME":
t2m, r2, tp, u10, v10 = read_arome(args.date)
forecast = np.stack([t2m, r2, tp, u10, v10], axis=-1)
models_names.append("AROME")
models_names.append("AROME Oper")
else:
model, hparams = get_model_and_hparams(ckpt, args.num_pred_steps)
item = get_item_for_date(args.date, hparams)
forecast = make_forecast(model, item)
forecast = post_process_outputs(forecast, feature_names, item.inputs.feature_names_to_idx)
feature_idx_dict = item.inputs.feature_names_to_idx
forecast = post_process_outputs(forecast, feature_names, feature_idx_dict)
models_names.append(f"{hparams.model_name}\n{hparams.save_path.name}")
y_preds.append(forecast)

y_true = item.outputs.tensor
y_true = post_process_outputs(y_true, feature_names, item.inputs.feature_names_to_idx)
y_true = post_process_outputs(y_true, feature_names, feature_idx_dict)
domain_info = hparams.dataset_info.domain_info
models_names = ["AROME Analysis"] + models_names

@@ -213,26 +234,7 @@ def plot_frame(feature_name: str, target: np.ndarray, predictions: List[np.ndarr

for feature_name in feature_names:
print(feature_name)

idx_feature = item.inputs.feature_names_to_idx[feature_name]
idx_feature = feature_idx_dict[feature_name]
target_feat = y_true[:,:,:, idx_feature]
list_preds_feat = [pred[:,:,:,idx_feature] for pred in y_preds]

vmin, vmax = target_feat.min(), target_feat.max()
date = dt.datetime.strptime(args.date, "%Y%m%d%H")
date_str = date.strftime("%Y-%m-%d %Hh UTC")
short_name = "_".join(feature_name.split("_")[:2])
feature_str = METADATA["WEATHER_PARAMS"][short_name]["long_name"][6:]
unit = f"{feature_str} ({hparams.dataset_info.units[feature_name]})"

frames = []
for t in trange(args.num_pred_steps):
title = f"{date_str} +{t+1}h"
target = target_feat[t]
list_preds = [pred[t] for pred in list_preds_feat]
frame = plot_frame(feature_name, target, list_preds, domain_info, title, models_names, unit, vmin, vmax)
frames.append(frame)
gif.save(frames, f"{args.date}_{feature_name}.gif", duration=250)

# TODO :
# - update README with script usage + gifs in main README
make_gif(feature_name, args.date, target_feat, list_preds_feat, models_names, domain_info)
1 change: 1 addition & 0 deletions py4cast/datasets/titan/settings.py
@@ -4,6 +4,7 @@
import yaml

SCRATCH_PATH = Path(os.environ.get("PY4CAST_TITAN_PATH", "/scratch/shared/Titan"))
AROME_PATH = SCRATCH_PATH / "AROME"
FORMATSTR = "%Y-%m-%d_%Hh%M"

with open(Path(__file__).parents[0] / "metadata.yaml", "r") as file:

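With the new `AROME_PATH` constant, the AROME archive location follows the same `PY4CAST_TITAN_PATH` override as the other Titan paths. A small illustration (the path below is hypothetical, and the variable must be set before the module is imported, since the path is resolved at import time):

```python
import os

os.environ["PY4CAST_TITAN_PATH"] = "/my/data/Titan"  # hypothetical location

from py4cast.datasets.titan.settings import AROME_PATH

print(AROME_PATH)  # /my/data/Titan/AROME
```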