From b9c1e300d17c5d4510fe1afe1c54f75cbc079b7a Mon Sep 17 00:00:00 2001 From: Lorenzo Breschi Date: Mon, 9 Sep 2024 11:42:34 +0000 Subject: [PATCH] era5 training bugfix --- train/era5.py | 50 +++++++++++++++++++++++++++++--------------------- 1 file changed, 29 insertions(+), 21 deletions(-) diff --git a/train/era5.py b/train/era5.py index 343d597..a83be4a 100644 --- a/train/era5.py +++ b/train/era5.py @@ -12,8 +12,10 @@ from einops import rearrange +from pathlib import Path -class LitGraphForecaster(pl.LightningModule): + +class LitFengWuGHR(pl.LightningModule): """ LightningModule for graph-based weather forecasting. @@ -23,7 +25,7 @@ class LitGraphForecaster(pl.LightningModule): lr : Learning rate for optimizer. Methods: - __init__: Initialize the LitGraphForecaster object. + __init__: Initialize the LitFengWuGHR object. forward: Forward pass of the model. training_step: Training step. configure_optimizers: Configure the optimizer for training. @@ -44,7 +46,7 @@ def __init__( ): """ - Initialize the LitGraphForecaster object with the required args. + Initialize the LitFengWuGHR object with the required args. Args: lat_lons : List of latitude and longitude values. @@ -135,16 +137,27 @@ def __getitem__(self, index): if __name__ == "__main__": + ckpt_path = Path("./checkpoints") patch_size = 4 grid_step = 20 + variables = ["2m_temperature", + "surface_pressure", + "10m_u_component_of_wind", + "10m_v_component_of_wind"] + + channels = len(variables) + ckpt_path.mkdir(parents=True, exist_ok=True) reanalysis = xarray.open_zarr( 'gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3', storage_options=dict(token='anon'), - ) - reanalysis = reanalysis.isel(time=slice(100, 400), longitude=slice( + + reanalysis = reanalysis.sel(time=slice('2020-01-01', '2021-01-01')) + reanalysis = reanalysis.isel(time=slice(100,107), longitude=slice( 0, 1440, grid_step), latitude=slice(0, 721, grid_step)) + + reanalysis = reanalysis[variables] print(f'size: {reanalysis.nbytes / (1024 ** 3)} GiB') lat_lons = np.array( @@ -155,27 +168,20 @@ def __getitem__(self, index): ).T.reshape((-1, 2)) checkpoint_callback = ModelCheckpoint( - dirpath="./checkpoints", save_top_k=1, monitor="loss") - reanalysis = reanalysis[["2m_temperature", - "surface_pressure", - "10m_u_component_of_wind", - "10m_v_component_of_wind"]] - - shape = np.asarray(reanalysis.to_array()).shape - channels = shape[0] + dirpath=ckpt_path, save_top_k=1, monitor="loss") dset = DataLoader(Era5Dataset(reanalysis), batch_size=10, num_workers=8) - model = LitGraphForecaster(lat_lons=lat_lons, - channels=channels, - image_size=(721//grid_step, 1440//grid_step), - patch_size=patch_size, - depth=5, - heads=4, - mlp_dim=5) + model = LitFengWuGHR(lat_lons=lat_lons, + channels=channels, + image_size=(721//grid_step, 1440//grid_step), + patch_size=patch_size, + depth=5, + heads=4, + mlp_dim=5) trainer = pl.Trainer( accelerator="gpu", devices=-1, - max_epochs=1000, + max_epochs=100, precision="16-mixed", callbacks=[checkpoint_callback], log_every_n_steps=3 @@ -183,3 +189,5 @@ def __getitem__(self, index): ) trainer.fit(model, dset) + + torch.save(model.state_dict(), ckpt_path / "best.pt")