Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fengwu ghr training #125

Merged
merged 54 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 53 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
cb70551
fengwu_ghr: initial
rnwzd May 29, 2024
9eaf70d
fengwu_ghr: fixes
rnwzd May 29, 2024
4f3d4c1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 29, 2024
8c60fb7
Interpolate initial
rnwzd Jun 6, 2024
725421d
ImageMetaModel
rnwzd Jun 11, 2024
c57a27e
MetaModel initial
rnwzd Jun 11, 2024
3d2a17d
tested metamodel
rnwzd Jun 14, 2024
48e7d0a
Merge branch 'main' into fengwu_ghr
rnwzd Jun 17, 2024
87d1ffd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2024
21d84c7
wrapper meta model
rnwzd Jun 21, 2024
07a8d0f
RES
rnwzd Jul 1, 2024
cd84968
load RES state_dict
rnwzd Jul 2, 2024
e0f60ca
Merge branch 'main' of https://github.com/openclimatefix/graph_weathe…
rnwzd Jul 2, 2024
b15110f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 2, 2024
1146db9
bug fix
rnwzd Jul 2, 2024
499ef7d
Merge branch 'fengwu_ghr' of https://github.com/openclimatefix/graph_…
rnwzd Jul 2, 2024
fe82edc
Merge branch 'main' of https://github.com/openclimatefix/graph_weathe…
rnwzd Jul 2, 2024
325fd0e
bug fix
rnwzd Jul 2, 2024
cfa9c3f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 2, 2024
2fadf97
env yml fix
rnwzd Jul 29, 2024
0f46d7a
Merge branch 'fengwu_ghr' of https://github.com/openclimatefix/graph_…
rnwzd Jul 29, 2024
257b353
fengwu_ghr: initial
rnwzd May 29, 2024
04d4776
Merge branch 'fengwu_ghr' of https://github.com/openclimatefix/graph_…
rnwzd Jul 29, 2024
f72c610
test_wrapper_meta_model
rnwzd Jul 29, 2024
8a8ac64
tests fix
rnwzd Jul 30, 2024
6f0c61b
parent 743cf9704eea03353e02351cde52add7233437d6
rnwzd May 29, 2024
1ae62eb
Merge branch 'fengwu_ghr' of https://github.com/openclimatefix/graph_…
rnwzd Jul 30, 2024
a855c6a
fengwu_ghr: initial
rnwzd May 29, 2024
47d0d48
fengwu_ghr: initial
rnwzd May 29, 2024
7a1d562
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 29, 2024
e397941
Interpolate initial
rnwzd Jun 6, 2024
80a73ee
ImageMetaModel
rnwzd Jun 11, 2024
127d8ff
MetaModel initial
rnwzd Jun 11, 2024
b59e54d
tested metamodel
rnwzd Jun 14, 2024
19e73a2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jun 17, 2024
e54016f
wrapper meta model
rnwzd Jun 21, 2024
dae738e
RES
rnwzd Jul 1, 2024
92acbee
load RES state_dict
rnwzd Jul 2, 2024
5115dd4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jul 2, 2024
8a6a062
Merge branch 'fengwu_ghr' of https://github.com/openclimatefix/graph_…
rnwzd Jul 30, 2024
362acdc
added gcsfs to env yml
rnwzd Aug 1, 2024
4dc0dc5
__init__.py imports
rnwzd Aug 7, 2024
31ee1e9
MetaModel long coordinates
rnwzd Aug 15, 2024
2970757
knn_interpolate gpu patch
rnwzd Aug 15, 2024
9f84835
era5 training
rnwzd Sep 3, 2024
b9c1e30
era5 training bugfix
rnwzd Sep 9, 2024
eba1335
lora training
rnwzd Sep 18, 2024
0cfbf40
Merge branch 'main' of https://github.com/openclimatefix/graph_weathe…
rnwzd Sep 18, 2024
7915c9b
pkg does not exist
rnwzd Sep 18, 2024
085ae70
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 18, 2024
8944050
env.yml bugfix
rnwzd Sep 19, 2024
968b908
Update environment_cpu.yml
jacobbieker Sep 23, 2024
7eb323a
Update environment_cuda.yml
jacobbieker Sep 23, 2024
384d99e
Update environment_cuda.yml
jacobbieker Sep 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@
# pixi environments
.pixi
.vscode/
checkpoints/
lightning_logs/
6 changes: 5 additions & 1 deletion environment_cpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ dependencies:
- zarr
- h3-py
- numpy
- torch_harmonics
- pyshtools
- gcsfs
- pytest
- pip:
- setuptools
- datasets
- einops
- fsspec
Expand All @@ -36,3 +39,4 @@ dependencies:
- click
- trimesh
- rtree
- torch-harmonics
10 changes: 7 additions & 3 deletions environment_cuda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@ channels:
- conda-forge
- defaults
dependencies:
- pytorch-cuda=12.1
- pytorch-cuda
- numcodecs
- pandas
- pip
- pyg
- python=3.12
- python
- pytorch
- pytorch-cluster
- pytorch-scatter
Expand All @@ -25,8 +25,11 @@ dependencies:
- zarr
- h3-py
- numpy
- torch_harmonics
- pyshtools
- gcsfs
- pytest
- pip:
- setuptools
- datasets
- einops
- fsspec
Expand All @@ -37,3 +40,4 @@ dependencies:
- click
- trimesh
- rtree
- torch-harmonics
8 changes: 7 additions & 1 deletion graph_weather/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
"""Models"""

from .fengwu_ghr.layers import ImageMetaModel, MetaModel, WrapperImageModel, WrapperMetaModel
from .fengwu_ghr.layers import (
ImageMetaModel,
LoRAModule,
MetaModel,
WrapperImageModel,
WrapperMetaModel,
)
from .layers.assimilator_decoder import AssimilatorDecoder
from .layers.assimilator_encoder import AssimilatorEncoder
from .layers.decoder import Decoder
Expand Down
3 changes: 3 additions & 0 deletions graph_weather/models/fengwu_ghr/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Main import for FengWu-GHR"""

from .layers import ImageMetaModel, LoRAModule, MetaModel, WrapperImageModel, WrapperMetaModel
48 changes: 47 additions & 1 deletion graph_weather/models/fengwu_ghr/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ def knn_interpolate(
squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
weights = 1.0 / torch.clamp(squared_distance, min=1e-16)

y_idx, x_idx = y_idx.to(x.device), x_idx.to(x.device)
weights = weights.to(x.device)

den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum")
y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum")

Expand Down Expand Up @@ -228,6 +231,7 @@ def __init__(
)

def forward(self, x):
assert x.shape[1] == self.channels, "Wrong number of channels"
device = x.device
dtype = x.dtype

Expand Down Expand Up @@ -276,7 +280,7 @@ def __init__(
super().__init__()
self.i_h, self.i_w = pair(image_size)

self.pos_x = torch.tensor(lat_lons)
self.pos_x = torch.tensor(lat_lons).to(torch.long)
self.pos_y = torch.cartesian_prod(
(torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long),
(torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long),
Expand Down Expand Up @@ -344,3 +348,45 @@ def forward(self, x):
x = rearrange(x, "n (b c) -> b n c", b=b, c=c)

return x


class LoRALayer(nn.Module):
    """Low-Rank Adaptation (LoRA) wrapper around a linear layer.

    Adds a trainable low-rank update ``B @ A`` on top of the wrapped layer's
    output. ``B`` is initialised to zeros, so at construction time the wrapper
    is an exact functional replacement for ``linear_layer``.
    """

    def __init__(self, linear_layer: nn.Module, r: int):
        """
        Initialize LoRALayer.

        Args:
            linear_layer (nn.Module): Linear layer to be transformed.
            r (int): rank of the low-rank matrix.
        """
        super().__init__()
        out_features, in_features = linear_layer.weight.shape

        # A is random, B is zero => the low-rank delta starts at exactly 0,
        # the standard LoRA initialisation.
        self.A = nn.Parameter(torch.randn(r, in_features))
        self.B = nn.Parameter(torch.zeros(out_features, r))
        self.linear_layer = linear_layer

    def forward(self, x):
        """Apply the wrapped linear layer plus the low-rank update.

        Args:
            x (torch.Tensor): Input of shape ``(..., in_features)``.

        Returns:
            torch.Tensor: Output of shape ``(..., out_features)``.
        """
        # Bug fix: the previous ``self.B @ self.A @ x`` only type-checks when
        # ``x`` is a single vector (or has the batch as its *last* dim); for
        # the conventional ``(..., in_features)`` layout used by nn.Linear it
        # raises a shape error. Right-multiplying by the transposes applies
        # the same low-rank map to arbitrarily batched inputs.
        return self.linear_layer(x) + x @ self.A.t() @ self.B.t()


class LoRAModule(nn.Module):
    """Wrap a model so every ``nn.Linear`` inside it is replaced by a ``LoRALayer``."""

    def __init__(self, model, r=4):
        """
        Initialize LoRAModule.

        Args:
            model (nn.Module): Model to be modified with LoRA layers (mutated in place).
            r (int, optional): Rank of LoRA layers. Defaults to 4.
        """
        super().__init__()
        # Snapshot the module list first: we mutate the tree while iterating.
        for name, layer in list(model.named_modules()):
            layer.eval()
            if isinstance(layer, nn.Linear):
                lora_layer = LoRALayer(layer, r)
                # Bug fix: ``setattr(model, name, ...)`` with a dotted name such
                # as "encoder.0.proj" registers a *new* attribute on the root
                # model instead of replacing the nested submodule, so nested
                # Linear layers kept being used unchanged. Resolve the parent
                # module and replace the leaf attribute instead.
                *parent_path, leaf = name.split(".")
                parent = model.get_submodule(".".join(parent_path)) if parent_path else model
                setattr(parent, leaf, lora_layer)
        self.model = model

    def forward(self, x):
        """Run the (LoRA-augmented) wrapped model."""
        return self.model(x)
33 changes: 30 additions & 3 deletions tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,6 @@ def test_image_meta_model():

out = model(image)
assert not torch.isnan(out).any()
assert not torch.isnan(out).any()
assert out.size() == image.size()


Expand All @@ -275,7 +274,6 @@ def test_wrapper_image_meta_model():
big_model = WrapperImageModel(model, scale_factor)
out = big_model(big_image)
assert not torch.isnan(out).any()
assert not torch.isnan(out).any()
assert out.size() == big_image.size()


Expand Down Expand Up @@ -303,5 +301,34 @@ def test_meta_model():

out = model(features)
assert not torch.isnan(out).any()
assert not torch.isnan(out).any()
assert out.size() == features.size()


def test_wrapper_meta_model():
    """WrapperMetaModel output must be NaN-free and keep the input's shape."""
    lat_lons = [(lat, lon) for lat in range(-90, 90, 5) for lon in range(0, 360, 5)]

    batch = 2
    channels = 3
    image_size = 20
    patch_size = 4
    scale_factor = 3

    model = MetaModel(
        lat_lons,
        image_size=image_size,
        patch_size=patch_size,
        depth=1,
        heads=1,
        mlp_dim=7,
        channels=channels,
        dim_head=64,
    )
    big_model = WrapperMetaModel(lat_lons, model, scale_factor)

    big_features = torch.randn((batch, len(lat_lons), channels))
    out = big_model(big_features)

    assert not torch.isnan(out).any()
    assert out.size() == big_features.size()
192 changes: 192 additions & 0 deletions train/era5.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
import xarray
from einops import rearrange
from pytorch_lightning.callbacks import ModelCheckpoint
from torch.utils.data import DataLoader, Dataset

from graph_weather.models import MetaModel
from graph_weather.models.losses import NormalizedMSELoss


class LitFengWuGHR(pl.LightningModule):
    """
    LightningModule wrapping the FengWu-GHR ``MetaModel`` for weather forecasting.

    Attributes:
        model (MetaModel): Vision-transformer style forecasting model.
        criterion (NormalizedMSELoss): Latitude-weighted, variance-normalised
            MSE training loss.
        lr (float): Learning rate for the AdamW optimiser.
    """

    def __init__(
        self,
        lat_lons: list,
        *,
        channels: int,
        image_size,
        patch_size=4,
        depth=5,
        heads=4,
        mlp_dim=5,
        feature_dim: int = 605,  # TODO where does this come from?
        lr: float = 3e-4,
    ):
        """
        Build the underlying model and the loss criterion.

        Args:
            lat_lons: List of (latitude, longitude) pairs, one per grid node.
            channels (int): Number of physical variables per node.
            image_size: Height/width the sphere is interpolated onto.
            patch_size: ViT patch size. Defaults to 4.
            depth: Number of transformer blocks. Defaults to 5.
            heads: Number of attention heads. Defaults to 4.
            mlp_dim: Hidden dimension of the transformer MLPs. Defaults to 5.
            feature_dim (int): Length of the feature-variance vector fed to
                the loss.  # TODO confirm where the default of 605 comes from
            lr (float): Learning rate for the optimizer.
        """
        super().__init__()
        self.model = MetaModel(
            lat_lons,
            image_size=image_size,
            patch_size=patch_size,
            depth=depth,
            heads=heads,
            mlp_dim=mlp_dim,
            channels=channels,
        )
        # Unit variance per feature for now; replace once real variances exist.
        self.criterion = NormalizedMSELoss(
            lat_lons=lat_lons, feature_variance=np.ones((feature_dim,))
        )
        self.lr = lr
        self.save_hyperparameters()

    def forward(self, x):
        """
        Run the forecasting model.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """
        Run one training step.

        Args:
            batch: Tensor whose second axis stacks (input, target) states.
            batch_idx (int): Index of the current batch.

        Returns:
            torch.Tensor | None: Loss, or ``None`` to skip a NaN-tainted batch.
        """
        x, y = batch[:, 0], batch[:, 1]
        # Skip batches containing NaNs instead of poisoning the weights.
        if torch.isnan(x).any() or torch.isnan(y).any():
            return None
        loss = self.criterion(self.forward(x), y)
        self.log("loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        """
        Configure the optimizer.

        Returns:
            torch.optim.Optimizer: AdamW optimizer instance.
        """
        return torch.optim.AdamW(self.parameters(), lr=self.lr)


class Era5Dataset(Dataset):
    """ERA5 dataset yielding pairs of consecutive timesteps.

    Each item is a tensor of shape ``(2, H*W, C)``: the normalised state at
    time ``t`` stacked with the state at time ``t + 1``.
    """

    def __init__(self, xarr, transform=None):
        """
        Preprocess the reanalysis data.

        Arguments:
            xarr: xarray Dataset (any object with ``to_array()``) whose array
                form has shape (variable, time, lat, lon).
            transform (callable, optional): Applied to each sample returned by
                ``__getitem__``. Bug fix: previously this argument was
                silently ignored; it is now honoured.
        """
        ds = torch.from_numpy(np.asarray(xarr.to_array()))
        # Min-max normalisation along dim 0 (the variable axis).
        # NOTE(review): this normalises each (time, lat, lon) cell *across
        # variables*, not each variable across space/time — confirm intended.
        # Out-of-place ops: ``from_numpy`` shares memory with the source
        # array, so in-place normalisation would mutate the caller's data.
        ds = ds - ds.min(0, keepdim=True)[0]
        ds = ds / ds.max(0, keepdim=True)[0]
        # (C, T, H, W) -> (T, H*W, C); equivalent to
        # einops.rearrange(ds, "C T H W -> T (H W) C").
        self.ds = ds.permute(1, 2, 3, 0).flatten(1, 2)
        self.transform = transform

    def __len__(self):
        """Number of consecutive (t, t+1) pairs available."""
        return len(self.ds) - 1

    def __getitem__(self, index):
        """Return the (t, t+1) pair at ``index``, transformed if requested."""
        sample = self.ds[index : index + 2]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


if __name__ == "__main__":

    # --- Configuration -----------------------------------------------------
    ckpt_path = Path("./checkpoints")
    patch_size = 4
    # Subsampling stride on the 0.25-degree lat/lon grid.
    grid_step = 20
    # Surface variables pulled from the ARCO-ERA5 archive.
    variables = [
        "2m_temperature",
        "surface_pressure",
        "10m_u_component_of_wind",
        "10m_v_component_of_wind",
    ]

    channels = len(variables)
    ckpt_path.mkdir(parents=True, exist_ok=True)

    # --- Data --------------------------------------------------------------
    # Public anonymous-access ERA5 zarr store on GCS (lazy/chunked access).
    reanalysis = xarray.open_zarr(
        "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
        storage_options=dict(token="anon"),
    )

    # NOTE(review): the year-long `.sel` below is immediately narrowed to 7
    # hourly timesteps (slice(100, 107)) — presumably a smoke-test sized
    # subset; confirm before a real training run.
    reanalysis = reanalysis.sel(time=slice("2020-01-01", "2021-01-01"))
    reanalysis = reanalysis.isel(
        time=slice(100, 107), longitude=slice(0, 1440, grid_step), latitude=slice(0, 721, grid_step)
    )

    reanalysis = reanalysis[variables]
    print(f"size: {reanalysis.nbytes / (1024 ** 3)} GiB")

    # (lat, lon) pair for every node of the subsampled grid, flattened to
    # shape (num_nodes, 2) in the order the model expects.
    lat_lons = np.array(
        np.meshgrid(
            np.asarray(reanalysis["latitude"]).flatten(),
            np.asarray(reanalysis["longitude"]).flatten(),
        )
    ).T.reshape((-1, 2))

    # Keep only the single best checkpoint as measured by training loss.
    checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path, save_top_k=1, monitor="loss")

    dset = DataLoader(Era5Dataset(reanalysis), batch_size=10, num_workers=8)

    # --- Model & training --------------------------------------------------
    model = LitFengWuGHR(
        lat_lons=lat_lons,
        channels=channels,
        image_size=(721 // grid_step, 1440 // grid_step),
        patch_size=patch_size,
        depth=5,
        heads=4,
        mlp_dim=5,
    )
    trainer = pl.Trainer(
        accelerator="gpu",
        devices=-1,  # use every available GPU
        max_epochs=100,
        precision="16-mixed",  # mixed-precision training
        callbacks=[checkpoint_callback],
        log_every_n_steps=3,
    )

    trainer.fit(model, dset)

    # Save the bare MetaModel weights (not the Lightning wrapper) for reuse.
    torch.save(model.model.state_dict(), ckpt_path / "best.pt")
Loading
Loading