diff --git a/README.md b/README.md index b68b0b38..e1bd760f 100644 --- a/README.md +++ b/README.md @@ -10,6 +10,9 @@ Contributions are welcome (Issues, Pull Requests, ...). This project is licensed under the [APACHE 2.0 license.](LICENSE-2.0.txt) +![Forecast humidity](figs/2023061812_aro_r2_2m_crop.gif) +![Forecast precip](figs/2023061812_aro_tp_0m.gif) + # Acknowledgements This project started as a fork of neural-lam, a project by Joel Oskarsson, see [here](https://github.com/mllam/neural-lam). Many thanks to Joel for his work! diff --git a/bin/gif_comparison.py b/bin/gif_comparison.py index c7ec423d..908036aa 100644 --- a/bin/gif_comparison.py +++ b/bin/gif_comparison.py @@ -1,3 +1,18 @@ +"""Plots animations comparing forecasts of multiple models with ground truth. +WARNING: - For now this script only works with models trained with the Titan dataset. + - If you want to use AROME as a model, you have to manually download the forecast beforehand. + +usage: gif_comparison.py [-h] --ckpt CKPT --date DATE [--num_pred_steps NUM_PRED_STEPS] + +options: + -h, --help show this help message and exit + --ckpt CKPT Paths to the model checkpoint or AROME + --date DATE Date for inference. Format YYYYMMDDHH. + --num_pred_steps NUM_PRED_STEPS + Number of auto-regressive steps/prediction steps. 
+ +example: python bin/gif_comparison.py --ckpt AROME --ckpt /.../logs/my_run/epoch=247.ckpt --date 2023061812 --num_pred_steps 10""" + import argparse @@ -18,48 +33,27 @@ from skimage.transform import resize import xarray as xr -from py4cast.datasets.titan.settings import METADATA - -BASE_PATH = Path("/scratch/shared/Titan/AROME/") +from py4cast.datasets.titan.settings import METADATA, AROME_PATH COLORMAPS = { - "t2m": { - "cmap": "Spectral_r", - "vmin": 240, - "vmax": 320 - }, - "r2": { - "cmap": "Spectral", - "vmin": 0, - "vmax": 100 - }, - "tp": { - "cmap": "Spectral_r", - "vmin": 0.5, - "vmax": 100 - }, - "u10": { - "cmap": "RdBu", - "vmin": -20, - "vmax": 20 - }, - "v10": { - "cmap": "RdBu", - "vmin": -20, - "vmax": 20 - } + "t2m": {"cmap": "Spectral_r", "vmin": 240, "vmax": 320}, + "r2": {"cmap": "Spectral", "vmin": 0, "vmax": 100}, + "tp": {"cmap": "Spectral_r", "vmin": 0.5, "vmax": 100}, + "u10": {"cmap": "RdBu","vmin": -20, "vmax": 20}, + "v10": {"cmap": "RdBu", "vmin": -20, "vmax": 20} } -def downscale(array): - subdomain=[100, 612, 240, 880] - grid_info = METADATA["GRIDS"]["PAAROME_1S40"] +def downscale(array:np.ndarray, grid: str = "PAAROME_1S40", domain: Tuple[int]=[100, 612, 240, 880]) -> np.ndarray: + """Downscales an array from Titan grid 1S100 to another grid and subdomain.""" + grid_info = METADATA["GRIDS"][grid] array = resize(array, grid_info["size"], anti_aliasing=True) - array = array[subdomain[0] : subdomain[1], subdomain[2] : subdomain[3]] + array = array[domain[0] : domain[1], domain[2] : domain[3]] return array def get_param(path:Path, param:str) -> np.ndarray: + """Extracts a weather param from an AROME forecast in grib.""" ds = xr.open_dataset(path, engine="cfgrib") array = ds[param].values arr_list = [downscale(array[t]) for t in range(array.shape[0])] @@ -68,13 +62,15 @@ def get_param(path:Path, param:str) -> np.ndarray: def post_process_tp_arome(array:np.ndarray) -> np.ndarray: + """Converts AROME precip forecast in mm/h. 
+ By default, AROME accumulates mm starting from t0.""" diff_arrs = [array[t+1] - array[t] for t in range(12)] return np.stack(diff_arrs) def read_arome(date:str) -> Tuple[np.ndarray]: - path = BASE_PATH / date - print(path) + """Extracts 5 parameters (t2m, r2, tp, u10, v10) of an AROME forecast.""" + path = AROME_PATH / date r2 = get_param(path / "AROME_1S100_ECH0_2M.grib", "r2") t2m = get_param(path / "AROME_1S100_ECH0_2M.grib", "t2m") u10 = get_param(path / "AROME_1S100_ECH0_10M.grib", "u10") @@ -85,6 +81,7 @@ def read_arome(date:str) -> Tuple[np.ndarray]: def get_model_and_hparams(ckpt: Path, num_pred_steps:int) -> Tuple[AutoRegressiveLightning, ArLightningHyperParam]: + """Loads a model from its checkpoint and changes the nb of forecast steps.""" model = AutoRegressiveLightning.load_from_checkpoint(ckpt) hparams = model.hparams["hparams"] hparams.num_pred_steps_val_test = num_pred_steps @@ -93,7 +90,7 @@ def get_model_and_hparams(ckpt: Path, num_pred_steps:int) -> Tuple[AutoRegressiv def get_item_for_date(date:str, hparams: ArLightningHyperParam) -> Item: - """ Returns Item containing sample of chosen date. + """Returns an Item containing one sample for a chosen date. Date should be in format YYYYMMDDHH. 
""" config_override = {"periods": {"test": {"start": date, "end": date}}} @@ -110,6 +107,7 @@ def get_item_for_date(date:str, hparams: ArLightningHyperParam) -> Item: def make_forecast(model: AutoRegressiveLightning, item: Item) -> torch.tensor: + """Applies a model to an Item to make a forecast.""" batch_item = collate_fn([item]) preds = model(batch_item) forecast = preds.tensor @@ -122,6 +120,7 @@ def make_forecast(model: AutoRegressiveLightning, item: Item) -> torch.tensor: def post_process_outputs(y: torch.tensor, feature_names:List[str], feature_names_to_idx:dict) -> np.ndarray: + """Post-processes one forecast by de-normalizing the values of each feature.""" arrays = [] for feature_name in feature_names: idx_feature = feature_names_to_idx[feature_name] @@ -134,7 +133,8 @@ def post_process_outputs(y: torch.tensor, feature_names:List[str], feature_names @gif.frame def plot_frame(feature_name: str, target: np.ndarray, predictions: List[np.ndarray], domain_info: DomainInfo, - title=None, models_names=None, unit:str=None, vmin:float=None, vmax:float=None)-> None: + title:str=None, models_names:List[str]=None, colorbar_label:str=None, vmin:float=None, vmax:float=None)-> None: + """Plots one frame of the animation.""" nb_preds = len(predictions) + 1 lines = int(math.sqrt(nb_preds)) @@ -157,7 +157,6 @@ def plot_frame(feature_name: str, target: np.ndarray, predictions: List[np.ndarr for i, data in enumerate([target] + predictions): axes[i].coastlines() - # array = data if param == "tp": # precipitations data = np.where(data < 0.5, np.nan, data) im = axes[i].imshow( @@ -166,15 +165,36 @@ def plot_frame(feature_name: str, target: np.ndarray, predictions: List[np.ndarr if models_names: axes[i].set_title(models_names[i], size=15) - subfig.colorbar(im, ax=axes, location='bottom', label=unit, aspect=40) + subfig.colorbar(im, ax=axes, location='bottom', label=colorbar_label, aspect=40) if title: fig.suptitle(title, size=20) + copyright = "Météo-France, Py4cast project." 
+ fig.text(0, 0.02, copyright, fontsize=8, ha="left") + + +def make_gif(feature: str, date:str, target: np.ndarray, preds: List[np.ndarray], models_names: List[str], domain_info: DomainInfo): + """Plots a GIF comparing multiple forecasts of one feature. NOTE: reads the module-level names `hparams` and `args` instead of taking them as parameters.""" + vmin, vmax = target.min(), target.max() + date = dt.datetime.strptime(date, "%Y%m%d%H") + date_str = date.strftime("%Y-%m-%d %Hh UTC") + short_name = "_".join(feature.split("_")[:2]) + feature_str = METADATA["WEATHER_PARAMS"][short_name]["long_name"][6:] + unit = f"{feature_str} ({hparams.dataset_info.units[feature]})" + + frames = [] + for t in trange(target.shape[0]): + title = f"{date_str} +{t+1}h" + preds_t = [pred[t] for pred in preds] + frame = plot_frame(feature, target[t], preds_t, domain_info, title, models_names, unit, vmin, vmax) + frames.append(frame) + gif.save(frames, f"{args.date}_{feature}.gif", duration=500) + if __name__ == "__main__": - parser = argparse.ArgumentParser("py4cast Inference script") + parser = argparse.ArgumentParser("Plot animations") parser.add_argument("--ckpt", type=str, action='append', help="Paths to the model checkpoint or AROME", required=True) parser.add_argument("--date", type=str, help="Date for inference. 
Format YYYYMMDDHH.", required=True) parser.add_argument( @@ -193,17 +213,18 @@ def plot_frame(feature_name: str, target: np.ndarray, predictions: List[np.ndarr if ckpt == "AROME": t2m, r2, tp, u10, v10 = read_arome(args.date) forecast = np.stack([t2m, r2, tp, u10, v10], axis=-1) - models_names.append("AROME") + models_names.append("AROME Oper") else: model, hparams = get_model_and_hparams(ckpt, args.num_pred_steps) item = get_item_for_date(args.date, hparams) forecast = make_forecast(model, item) - forecast = post_process_outputs(forecast, feature_names, item.inputs.feature_names_to_idx) + feature_idx_dict = item.inputs.feature_names_to_idx + forecast = post_process_outputs(forecast, feature_names, feature_idx_dict) models_names.append(f"{hparams.model_name}\n{hparams.save_path.name}") y_preds.append(forecast) y_true = item.outputs.tensor - y_true = post_process_outputs(y_true, feature_names, item.inputs.feature_names_to_idx) + y_true = post_process_outputs(y_true, feature_names, feature_idx_dict) domain_info = hparams.dataset_info.domain_info models_names = ["AROME Analysis"] + models_names @@ -213,26 +234,7 @@ def plot_frame(feature_name: str, target: np.ndarray, predictions: List[np.ndarr for feature_name in feature_names: print(feature_name) - - idx_feature = item.inputs.feature_names_to_idx[feature_name] + idx_feature = feature_idx_dict[feature_name] target_feat = y_true[:,:,:, idx_feature] list_preds_feat = [pred[:,:,:,idx_feature] for pred in y_preds] - - vmin, vmax = target_feat.min(), target_feat.max() - date = dt.datetime.strptime(args.date, "%Y%m%d%H") - date_str = date.strftime("%Y-%m-%d %Hh UTC") - short_name = "_".join(feature_name.split("_")[:2]) - feature_str = METADATA["WEATHER_PARAMS"][short_name]["long_name"][6:] - unit = f"{feature_str} ({hparams.dataset_info.units[feature_name]})" - - frames = [] - for t in trange(args.num_pred_steps): - title = f"{date_str} +{t+1}h" - target = target_feat[t] - list_preds = [pred[t] for pred in list_preds_feat] 
- frame = plot_frame(feature_name, target, list_preds, domain_info, title, models_names, unit, vmin, vmax) - frames.append(frame) - gif.save(frames, f"{args.date}_{feature_name}.gif", duration=250) - -# TODO : -# - update README with script usage + gifs in main README \ No newline at end of file + make_gif(feature_name, args.date, target_feat, list_preds_feat, models_names, domain_info) diff --git a/py4cast/datasets/titan/settings.py b/py4cast/datasets/titan/settings.py index dbd01f04..019cb6e9 100644 --- a/py4cast/datasets/titan/settings.py +++ b/py4cast/datasets/titan/settings.py @@ -4,6 +4,7 @@ import yaml SCRATCH_PATH = Path(os.environ.get("PY4CAST_TITAN_PATH", "/scratch/shared/Titan")) +AROME_PATH = SCRATCH_PATH / "AROME" FORMATSTR = "%Y-%m-%d_%Hh%M" with open(Path(__file__).parents[0] / "metadata.yaml", "r") as file: