GRIB output #17

Open · wants to merge 28 commits into main
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -18,7 +18,7 @@ repos:
         description: Check for spelling errors
         language: system
         entry: codespell
-        args: ['--ignore-words-list=laf']
+        args: ['--ignore-words-list=laf,pres']
   - repo: local
     hooks:
       - id: black
2 changes: 1 addition & 1 deletion create_parameter_weights.py
@@ -82,7 +82,7 @@ def main(rank, world_size): # pylint: disable=redefined-outer-name
         batch_size=args.batch_size,
         num_workers=args.n_workers,
     )
-    data_module.setup(stage="fit")
+    data_module.setup(stage="train")

     train_sampler = DistributedSampler(
         data_module.train_dataset, num_replicas=world_size, rank=rank
2 changes: 2 additions & 0 deletions environment.yml
@@ -7,6 +7,7 @@ dependencies:
   - Cartopy
   - dask
   - dask-jobqueue
+  - eccodes
   - imageio
   - ipython
   - matplotlib
@@ -29,6 +30,7 @@ dependencies:
   - xarray
   - zarr
   - pip:
+      - earthkit-data
       - tueplots
       - codespell>=2.0.0
       - black>=21.9b0
23 changes: 23 additions & 0 deletions neural_lam/constants.py
@@ -143,6 +143,24 @@
"V_10M": 0,
}

GRIB_NAME = {
"PP": "pres",
"QV": "q",
"RELHUM": "r",
"T": "t",
"U": "u",
"V": "v",
"W": "wz",
"CLCT": "ccl",
"PMSL": "prmsl",
"PS": "sp",
"T_2M": "2t",
"TOT_PREC": "tp",
"U_10M": "10u",
"V_10M": "10v",
}


# Vertical level weights
# These were retrieved based on the pressure levels of
# https://weatherbench2.readthedocs.io/en/latest/data-guide.html#era5
Expand Down Expand Up @@ -183,6 +201,11 @@
EVAL_PLOT_VARS = ["T_2M"]
STORE_EXAMPLE_DATA = True
SELECTED_PROJ = ccrs.PlateCarree()
SAMPLE_GRIB = "templates/lfff02180000"
SAMPLE_Z_GRIB = "templates/lfff02180000z"
EVAL_DATETIME = ["2020100215"]
POLLON = -170.0
POLLAT = 43.0
SMOOTH_BOUNDARIES = False

# Some constants useful for sub-classes 3 fluxes variables + 4 time-related
Expand Down
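
As a quick illustration of how the new GRIB_NAME mapping is meant to be used, here is a minimal sketch, assuming earthkit-data is installed and the sample template file from constants.SAMPLE_GRIB exists; the loop itself is not part of this diff:

```python
# Sketch (not part of the diff): select fields from the sample GRIB file
# by the GRIB shortNames defined in GRIB_NAME.
import earthkit.data

from neural_lam import constants

ds = earthkit.data.from_source("file", constants.SAMPLE_GRIB)
for var_name, short_name in constants.GRIB_NAME.items():
    subset = ds.sel(shortName=short_name)
    print(f"{var_name} -> {short_name}: {len(subset)} field(s)")
```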
92 changes: 92 additions & 0 deletions neural_lam/models/ar_model.py
@@ -5,6 +5,7 @@
 from datetime import datetime, timedelta

 # Third-party
+import earthkit.data
 import imageio
 import matplotlib.pyplot as plt
 import numpy as np
@@ -823,6 +824,7 @@ def on_predict_epoch_end(self):
             prediction_array = prediction_rescaled.cpu().numpy()
             file_path = os.path.join(value_dir_path, f"prediction_{i}.npy")
             np.save(file_path, prediction_array)
+            self.save_pred_as_grib(file_path, value_dir_path)

             # For plots
             for var_name, _ in self.selected_vars_units:
@@ -847,6 +849,96 @@ def on_predict_epoch_end(self):
             for filename in images:
                 image = imageio.imread(filename)
                 writer.append_data(image)
         self.spatial_loss_maps.clear()

+    def _generate_time_steps(self):
+        """Generate a dict mapping each inference step to its time stamp."""
+        # Parse the base evaluation time
+        base_time = constants.EVAL_DATETIMES[0]
+        if isinstance(base_time, str):
+            base_time = datetime.strptime(base_time, "%Y%m%d%H")
+        time_steps = {}
+        # Generate a date string for each step
+        for i in range(constants.EVAL_HORIZON - 2):
+            # Advance by the step interval in hours
+            new_date = base_time + timedelta(hours=i * constants.TRAIN_HORIZON)
+            # Format the date back to a string
+            time_steps[i] = new_date.strftime("%Y%m%d%H")
+
+        return time_steps
+
+    def save_pred_as_grib(self, file_path, value_dir_path):
+        """Save the prediction values into GRIB format."""
+        # Initialize the lists to loop over
+        indices = self.precompute_variable_indices()
+        time_steps = self._generate_time_steps()
+        # Loop through all the time steps and all the variables
+        for time_idx, date_str in time_steps.items():
+            # Initialize the final data object
+            final_data = earthkit.data.FieldList()
+            for variable, grib_code in constants.GRIB_NAME.items():
+                # Look the variable up in constants.IS_3D: 3D fields are
+                # reshaped over all vertical levels, 2D fields over one
+                if constants.IS_3D[variable]:
+                    shape_val = len(constants.VERTICAL_LEVELS)
+                    vertical = constants.VERTICAL_LEVELS
+                else:
+                    # Special handling for T_2M and *_10M variables
+                    if variable == "T_2M":
+                        shape_val = 1
+                        vertical = 2
+                    elif variable.endswith("_10M"):
+                        shape_val = 1
+                        vertical = 10
+                    else:
+                        shape_val = 1
+                        vertical = 0
+                # Find the value range to sample
+                value_range = indices[variable]

+                sample_file = constants.SAMPLE_GRIB
+                if variable == "RELHUM":
+                    variable = "r"
+                    sample_file = constants.SAMPLE_Z_GRIB
+
+                # Load the sample GRIB file
+                original_data = earthkit.data.from_source("file", sample_file)
+
+                subset = original_data.sel(shortName=grib_code, level=vertical)
+                md = subset.metadata()
+
+                # Split the datestring into date and time, then override
+                # both values in the metadata
+                date = date_str[:8]
+                time = date_str[8:]
+
+                for index, item in enumerate(md):
+                    md[index] = item.override({"date": date}).override(
+                        {"time": time}
+                    )
+                if len(md) > 0:
+                    # Load the array to replace the values with
+                    replacement_data = np.load(file_path)
+                    original_cut = replacement_data[
+                        0, time_idx, :, min(value_range) : max(value_range) + 1
+                    ].reshape(
+                        constants.GRID_SHAPE[1],
+                        constants.GRID_SHAPE[0],
+                        shape_val,
+                    )
+                    cut_values = np.moveaxis(
+                        original_cut, [-3, -2, -1], [-1, -2, -3]
+                    )
+                    # Can we stack FieldLists?
+                    data_new = earthkit.data.FieldList.from_array(
+                        cut_values, md
+                    )
+                    final_data += data_new
+            # Create the modified GRIB file with the predicted data
+            grib_path = os.path.join(
+                value_dir_path, f"prediction_{date_str}_grib"
+            )
+            final_data.save(grib_path)

     def on_load_checkpoint(self, checkpoint):
         """
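
The core of save_pred_as_grib is a template-and-override pattern: read metadata from a sample GRIB file, re-stamp it with the forecast date and time, and wrap the predicted array in a FieldList before saving. A condensed, self-contained sketch of that pattern follows; the template path, shortName, and random array are placeholders, not the PR's exact call sequence:

```python
# Condensed sketch of the template-and-override pattern used above.
# Template path, shortName, and the random data are placeholders.
import earthkit.data
import numpy as np

template = earthkit.data.from_source("file", "templates/lfff02180000")
subset = template.sel(shortName="t")  # one variable from the template

# Re-stamp the template metadata with the forecast date and time
md = [
    m.override({"date": "20201002"}).override({"time": "15"})
    for m in subset.metadata()
]

# Wrap model output (here random values) in a FieldList and save as GRIB
values = np.random.rand(len(md), *subset[0].shape)
fields = earthkit.data.FieldList.from_array(values, md)
fields.save("prediction_2020100215_grib")
```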
3 changes: 2 additions & 1 deletion neural_lam/weather_dataset.py
@@ -257,14 +257,15 @@ def __init__(
         self.predict_dataset = None

     def setup(self, stage=None):
-        if stage == "fit" or stage is None:
+        if stage == "train" or stage is None:
             self.train_dataset = WeatherDataset(
                 dataset_name=self.dataset_name,
                 split="train",
                 standardize=self.standardize,
                 subset=self.subset,
                 batch_size=self.batch_size,
             )
+        elif stage == "val":
             self.val_dataset = WeatherDataset(
                 dataset_name=self.dataset_name,
                 split="val",
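
Note that setup() now keys on "train" and "val" rather than Lightning's built-in "fit" stage, matching the create_parameter_weights.py change above. A short usage sketch; the WeatherDataModule name and constructor arguments are assumptions inferred from the surrounding diff, not shown in it:

```python
# Sketch: driving the renamed stages manually. Class name and arguments
# are assumptions based on the surrounding diff.
data_module = WeatherDataModule(
    dataset_name="cosmo",  # placeholder dataset name
    batch_size=1,
    num_workers=1,
)
data_module.setup(stage="train")  # builds train_dataset only
data_module.setup(stage="val")    # builds val_dataset only
```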
1 change: 1 addition & 0 deletions requirements.txt
@@ -4,6 +4,7 @@ wandb>=0.13.10
 matplotlib>=3.7.0
 dask
 dask_jobqueue
+earthkit-data
 scipy>=1.10.0
 pytorch-lightning>=2.0.3
 shapely>=2.0.1
1 change: 0 additions & 1 deletion slurm_param.sh
@@ -4,7 +4,6 @@
 #SBATCH --time=24:00:00
 #SBATCH --nodes=2
 #SBATCH --partition=postproc
-#SBATCH --mem=444G
 #SBATCH --no-requeue
 #SBATCH --exclusive
 #SBATCH --output=lightning_logs/neurwp_param_out.log
7 changes: 3 additions & 4 deletions slurm_predict.sh
@@ -1,10 +1,9 @@
 #!/bin/bash -l
 #SBATCH --job-name=NeurWPp
 #SBATCH --account=s83
-#SBATCH --partition=normal
+#SBATCH --partition=pp-short
 #SBATCH --nodes=1
 #SBATCH --ntasks-per-node=4
-#SBATCH --mem=444G
 #SBATCH --time=00:59:00
 #SBATCH --no-requeue
 #SBATCH --output=lightning_logs/neurwp_pred_out.log
@@ -39,7 +38,7 @@ fi

 echo "Predicting with model"
 if [ "$MODEL" = "hi_lam" ]; then
-    srun -ul python train_model.py --dataset $DATASET --val_interval 2 --epochs 1 --n_workers 12 --batch_size 1 --subset_ds 1 --model hi_lam --graph hierarchical --load wandb/example.ckpt --eval="predict"
+    srun -ul python train_model.py --dataset $DATASET --epochs 1 --n_workers 1 --batch_size 1 --subset_ds 1 --model hi_lam --graph hierarchical --load wandb/example.ckpt --eval="predict"
 else
-    srun -ul python train_model.py --dataset $DATASET --val_interval 2 --epochs 1 --n_workers 12 --batch_size 1 --subset_ds 1 --load "wandb/example.ckpt" --eval="predict"
+    srun -ul python train_model.py --dataset $DATASET --epochs 1 --n_workers 1 --batch_size 1 --subset_ds 1 --load "wandb/example.ckpt" --eval="predict"
 fi