diff --git a/.gitignore b/.gitignore
index d248bf98..e450bd8f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,3 +7,5 @@
 # pixi environments
 .pixi
 .vscode/
+checkpoints/
+lightning_logs/
diff --git a/environment_cpu.yml b/environment_cpu.yml
index db246dba..383d4a1e 100644
--- a/environment_cpu.yml
+++ b/environment_cpu.yml
@@ -24,8 +24,11 @@ dependencies:
   - zarr
   - h3-py
   - numpy
-  - torch_harmonics
+  - pyshtools
+  - gcsfs
+  - pytest
   - pip:
+      - setuptools
       - datasets
       - einops
       - fsspec
@@ -36,3 +39,4 @@ dependencies:
       - click
       - trimesh
       - rtree
+      - torch-harmonics
diff --git a/environment_cuda.yml b/environment_cuda.yml
index 0b86f8fb..e1eb18d4 100644
--- a/environment_cuda.yml
+++ b/environment_cuda.yml
@@ -6,7 +6,7 @@ channels:
   - conda-forge
   - defaults
 dependencies:
-  - pytorch-cuda=12.1
+  - pytorch-cuda
   - numcodecs
   - pandas
   - pip
@@ -25,8 +25,11 @@ dependencies:
   - zarr
   - h3-py
   - numpy
-  - torch_harmonics
+  - pyshtools
+  - gcsfs
+  - pytest
   - pip:
+      - setuptools
       - datasets
       - einops
       - fsspec
@@ -37,3 +40,4 @@ dependencies:
       - click
       - trimesh
       - rtree
+      - torch-harmonics
diff --git a/graph_weather/models/__init__.py b/graph_weather/models/__init__.py
index ace964db..91909420 100644
--- a/graph_weather/models/__init__.py
+++ b/graph_weather/models/__init__.py
@@ -1,6 +1,12 @@
 """Models"""
 
-from .fengwu_ghr.layers import ImageMetaModel, MetaModel, WrapperImageModel, WrapperMetaModel
+from .fengwu_ghr.layers import (
+    ImageMetaModel,
+    LoRAModule,
+    MetaModel,
+    WrapperImageModel,
+    WrapperMetaModel,
+)
 from .layers.assimilator_decoder import AssimilatorDecoder
 from .layers.assimilator_encoder import AssimilatorEncoder
 from .layers.decoder import Decoder
diff --git a/graph_weather/models/fengwu_ghr/__init__.py b/graph_weather/models/fengwu_ghr/__init__.py
new file mode 100644
index 00000000..39b921fd
--- /dev/null
+++ b/graph_weather/models/fengwu_ghr/__init__.py
@@ -0,0 +1,3 @@
+"""Main import for FengWu-GHR"""
+
+from .layers import ImageMetaModel, LoRAModule, MetaModel, WrapperImageModel, WrapperMetaModel
diff --git a/graph_weather/models/fengwu_ghr/layers.py b/graph_weather/models/fengwu_ghr/layers.py
index d129d2dd..38cf43ab 100644
--- a/graph_weather/models/fengwu_ghr/layers.py
+++ b/graph_weather/models/fengwu_ghr/layers.py
@@ -22,6 +22,9 @@ def knn_interpolate(
         squared_distance = (diff * diff).sum(dim=-1, keepdim=True)
         weights = 1.0 / torch.clamp(squared_distance, min=1e-16)
 
+    y_idx, x_idx = y_idx.to(x.device), x_idx.to(x.device)
+    weights = weights.to(x.device)
+
     den = scatter(weights, y_idx, 0, pos_y.size(0), reduce="sum")
     y = scatter(x[x_idx] * weights, y_idx, 0, pos_y.size(0), reduce="sum")
 
@@ -228,6 +231,7 @@ def __init__(
         )
 
     def forward(self, x):
+        assert x.shape[1] == self.channels, "Wrong number of channels"
         device = x.device
         dtype = x.dtype
 
@@ -276,7 +280,7 @@ def __init__(
         super().__init__()
         self.i_h, self.i_w = pair(image_size)
 
-        self.pos_x = torch.tensor(lat_lons)
+        self.pos_x = torch.tensor(lat_lons).to(torch.long)
         self.pos_y = torch.cartesian_prod(
             (torch.arange(-self.i_h / 2, self.i_h / 2, 1) / self.i_h * 180).to(torch.long),
             (torch.arange(0, self.i_w, 1) / self.i_w * 360).to(torch.long),
@@ -344,3 +348,45 @@ def forward(self, x):
         x = rearrange(x, "n (b c) -> b n c", b=b, c=c)
 
         return x
+
+
+class LoRALayer(nn.Module):
+    def __init__(self, linear_layer: nn.Module, r: int):
+        """
+        Initialize LoRALayer.
+
+        Args:
+            linear_layer (nn.Module): Linear layer to be transformed.
+            r (int): rank of the low-rank matrix.
+ """ + super().__init__() + out_features, in_features = linear_layer.weight.shape + + self.A = nn.Parameter(torch.randn(r, in_features)) + self.B = nn.Parameter(torch.zeros(out_features, r)) + self.linear_layer = linear_layer + + def forward(self, x): + out = self.linear_layer(x) + self.B @ self.A @ x + return out + + +class LoRAModule(nn.Module): + def __init__(self, model, r=4): + """ + Initialize LoRAModule. + + Args: + model (nn.Module): Model to be modified with LoRA layers. + r (int, optional): Rank of LoRA layers. Defaults to 4. + """ + super().__init__() + for name, layer in model.named_modules(): + layer.eval() + if isinstance(layer, nn.Linear): + lora_layer = LoRALayer(layer, r) + setattr(model, name, lora_layer) + self.model = model + + def forward(self, x): + return self.model(x) diff --git a/tests/test_model.py b/tests/test_model.py index 94ed991b..9fae0a1d 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -252,7 +252,6 @@ def test_image_meta_model(): out = model(image) assert not torch.isnan(out).any() - assert not torch.isnan(out).any() assert out.size() == image.size() @@ -275,7 +274,6 @@ def test_wrapper_image_meta_model(): big_model = WrapperImageModel(model, scale_factor) out = big_model(big_image) assert not torch.isnan(out).any() - assert not torch.isnan(out).any() assert out.size() == big_image.size() @@ -303,5 +301,34 @@ def test_meta_model(): out = model(features) assert not torch.isnan(out).any() - assert not torch.isnan(out).any() assert out.size() == features.size() + + +def test_wrapper_meta_model(): + lat_lons = [] + for lat in range(-90, 90, 5): + for lon in range(0, 360, 5): + lat_lons.append((lat, lon)) + + batch = 2 + channels = 3 + image_size = 20 + patch_size = 4 + scale_factor = 3 + model = MetaModel( + lat_lons, + image_size=image_size, + patch_size=patch_size, + depth=1, + heads=1, + mlp_dim=7, + channels=channels, + dim_head=64, + ) + + big_features = torch.randn((batch, len(lat_lons), channels)) + big_model = WrapperMetaModel(lat_lons, model, scale_factor) + out = big_model(big_features) + + assert not torch.isnan(out).any() + assert out.size() == big_features.size() diff --git a/train/era5.py b/train/era5.py new file mode 100644 index 00000000..ab5eee2d --- /dev/null +++ b/train/era5.py @@ -0,0 +1,192 @@ +from pathlib import Path + +import numpy as np +import pytorch_lightning as pl +import torch +import xarray +from einops import rearrange +from pytorch_lightning.callbacks import ModelCheckpoint +from torch.utils.data import DataLoader, Dataset + +from graph_weather.models import MetaModel +from graph_weather.models.losses import NormalizedMSELoss + + +class LitFengWuGHR(pl.LightningModule): + """ + LightningModule for graph-based weather forecasting. + + Attributes: + model (GraphWeatherForecaster): Graph weather forecaster model. + criterion (NormalizedMSELoss): Loss criterion for training. + lr : Learning rate for optimizer. + + Methods: + __init__: Initialize the LitFengWuGHR object. + forward: Forward pass of the model. + training_step: Training step. + configure_optimizers: Configure the optimizer for training. + """ + + def __init__( + self, + lat_lons: list, + *, + channels: int, + image_size, + patch_size=4, + depth=5, + heads=4, + mlp_dim=5, + feature_dim: int = 605, # TODO where does this come from? + lr: float = 3e-4, + ): + """ + Initialize the LitFengWuGHR object with the required args. + + Args: + lat_lons : List of latitude and longitude values. + feature_dim : Dimensionality of the input features. 
+            channels : Number of input channels (variables).
+            image_size : Size of the internal lat/lon grid (int or (height, width)).
+            patch_size, depth, heads, mlp_dim : ViT hyperparameters passed to MetaModel.
+            lr (float): Learning rate for optimizer.
+        """
+        super().__init__()
+        self.model = MetaModel(
+            lat_lons,
+            image_size=image_size,
+            patch_size=patch_size,
+            depth=depth,
+            heads=heads,
+            mlp_dim=mlp_dim,
+            channels=channels,
+        )
+        self.criterion = NormalizedMSELoss(
+            lat_lons=lat_lons, feature_variance=np.ones((feature_dim,))
+        )
+        self.lr = lr
+        self.save_hyperparameters()
+
+    def forward(self, x):
+        """
+        Forward pass.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            torch.Tensor: Output tensor.
+        """
+        return self.model(x)
+
+    def training_step(self, batch, batch_idx):
+        """
+        Training step.
+
+        Args:
+            batch (array): Batch of data containing input and output tensors.
+            batch_idx (int): Index of the current batch.
+
+        Returns:
+            torch.Tensor: Loss tensor.
+        """
+        x, y = batch[:, 0], batch[:, 1]
+        if torch.isnan(x).any() or torch.isnan(y).any():
+            return None
+        y_hat = self.forward(x)
+        loss = self.criterion(y_hat, y)
+        self.log("loss", loss, prog_bar=True)
+        return loss
+
+    def configure_optimizers(self):
+        """
+        Configure the optimizer.
+
+        Returns:
+            torch.optim.Optimizer: Optimizer instance.
+        """
+        return torch.optim.AdamW(self.parameters(), lr=self.lr)
+
+
+class Era5Dataset(Dataset):
+    """Era5 dataset."""
+
+    def __init__(self, xarr, transform=None):
+        """
+        Arguments:
+            xarr (xarray.Dataset): Dataset holding the ERA5 variables to normalise.
+        """
+        ds = np.asarray(xarr.to_array())
+        ds = torch.from_numpy(ds)
+        ds -= ds.min(0, keepdim=True)[0]
+        ds /= ds.max(0, keepdim=True)[0]
+        ds = rearrange(ds, "C T H W -> T (H W) C")
+        self.ds = ds
+
+    def __len__(self):
+        return len(self.ds) - 1
+
+    def __getitem__(self, index):
+        return self.ds[index : index + 2]
+
+
+if __name__ == "__main__":
+
+    ckpt_path = Path("./checkpoints")
+    patch_size = 4
+    grid_step = 20
+    variables = [
+        "2m_temperature",
+        "surface_pressure",
+        "10m_u_component_of_wind",
+        "10m_v_component_of_wind",
+    ]
+
+    channels = len(variables)
+    ckpt_path.mkdir(parents=True, exist_ok=True)
+
+    reanalysis = xarray.open_zarr(
+        "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
+        storage_options=dict(token="anon"),
+    )
+
+    reanalysis = reanalysis.sel(time=slice("2020-01-01", "2021-01-01"))
+    reanalysis = reanalysis.isel(
+        time=slice(100, 107), longitude=slice(0, 1440, grid_step), latitude=slice(0, 721, grid_step)
+    )
+
+    reanalysis = reanalysis[variables]
+    print(f"size: {reanalysis.nbytes / (1024 ** 3)} GiB")
+
+    lat_lons = np.array(
+        np.meshgrid(
+            np.asarray(reanalysis["latitude"]).flatten(),
+            np.asarray(reanalysis["longitude"]).flatten(),
+        )
+    ).T.reshape((-1, 2))
+
+    checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path, save_top_k=1, monitor="loss")
+
+    dset = DataLoader(Era5Dataset(reanalysis), batch_size=10, num_workers=8)
+    model = LitFengWuGHR(
+        lat_lons=lat_lons,
+        channels=channels,
+        image_size=(721 // grid_step, 1440 // grid_step),
+        patch_size=patch_size,
+        depth=5,
+        heads=4,
+        mlp_dim=5,
+    )
+    trainer = pl.Trainer(
+        accelerator="gpu",
+        devices=-1,
+        max_epochs=100,
+        precision="16-mixed",
+        callbacks=[checkpoint_callback],
+        log_every_n_steps=3,
+    )
+
+    trainer.fit(model, dset)
+
+    torch.save(model.model.state_dict(), ckpt_path / "best.pt")
diff --git a/train/lora.py b/train/lora.py
new file mode 100644
index 00000000..f829c797
--- /dev/null
+++ b/train/lora.py
@@ -0,0 +1,165 @@
+from pathlib import Path
+
+import numpy as np
+import pytorch_lightning as pl
+import torch
+import torch.nn as nn
+import xarray
+from einops import rearrange
+from pytorch_lightning.callbacks import ModelCheckpoint
+from torch.utils.data import DataLoader, Dataset
+
+from graph_weather.models import LoRAModule, MetaModel
+from graph_weather.models.losses import NormalizedMSELoss
+
+
+class LitLoRAFengWuGHR(pl.LightningModule):
+    def __init__(
+        self,
+        lat_lons: list,
+        single_step_model_state_dict: dict,
+        *,
+        time_step: int,
+        rank: int,
+        channels: int,
+        image_size,
+        patch_size=4,
+        depth=5,
+        heads=4,
+        mlp_dim=5,
+        feature_dim: int = 605,  # TODO where does this come from?
+        lr: float = 3e-4,
+    ):
+        super().__init__()
+        assert (
+            time_step > 1
+        ), "Time step must be greater than 1. Remember that 1 is the simple model time step."
+        ssmodel = MetaModel(
+            lat_lons,
+            image_size=image_size,
+            patch_size=patch_size,
+            depth=depth,
+            heads=heads,
+            mlp_dim=mlp_dim,
+            channels=channels,
+        )
+        ssmodel.load_state_dict(single_step_model_state_dict)
+        self.models = nn.ModuleList(
+            [ssmodel] + [LoRAModule(ssmodel, r=rank) for _ in range(2, time_step + 1)]
+        )
+        self.criterion = NormalizedMSELoss(
+            lat_lons=lat_lons, feature_variance=np.ones((feature_dim,))
+        )
+        self.lr = lr
+        self.save_hyperparameters()
+
+    def forward(self, x):
+        ys = []
+        for model in self.models:
+            x = model(x)
+            ys.append(x)
+        return torch.stack(ys, dim=1)
+
+    def training_step(self, batch, batch_idx):
+        if torch.isnan(batch).any():
+            return None
+        x, ys = batch[:, 0, ...], batch[:, 1:, ...]
+
+        y_hat = self.forward(x)
+        loss = self.criterion(y_hat, ys)
+        self.log("loss", loss, prog_bar=True)
+        return loss
+
+    def configure_optimizers(self):
+        return torch.optim.AdamW(self.parameters(), lr=self.lr)
+
+
+class Era5Dataset(Dataset):
+
+    def __init__(self, xarr, time_step=1, transform=None):
+        assert time_step > 0, "Time step must be greater than 0."
+        ds = np.asarray(xarr.to_array())
+        ds = torch.from_numpy(ds)
+        ds -= ds.min(0, keepdim=True)[0]
+        ds /= ds.max(0, keepdim=True)[0]
+        ds = rearrange(ds, "C T H W -> T (H W) C")
+        self.ds = ds
+        self.time_step = time_step
+
+    def __len__(self):
+        return len(self.ds) - self.time_step
+
+    def __getitem__(self, index):
+        return self.ds[index : index + self.time_step + 1]
+
+
+if __name__ == "__main__":
+
+    ckpt_path = Path("./checkpoints")
+    ckpt_name = "best.pt"
+    patch_size = 4
+    grid_step = 20
+    time_step = 2
+    rank = 4
+    variables = [
+        "2m_temperature",
+        "surface_pressure",
+        "10m_u_component_of_wind",
+        "10m_v_component_of_wind",
+    ]
+
+    ###############################################################
+
+    channels = len(variables)
+    ckpt_path.mkdir(parents=True, exist_ok=True)
+
+    reanalysis = xarray.open_zarr(
+        "gs://gcp-public-data-arco-era5/ar/full_37-1h-0p25deg-chunk-1.zarr-v3",
+        storage_options=dict(token="anon"),
+    )
+
+    reanalysis = reanalysis.sel(time=slice("2020-01-01", "2021-01-01"))
+    reanalysis = reanalysis.isel(
+        time=slice(100, 111), longitude=slice(0, 1440, grid_step), latitude=slice(0, 721, grid_step)
+    )
+
+    reanalysis = reanalysis[variables]
+    print(f"size: {reanalysis.nbytes / (1024 ** 3)} GiB")
+
+    lat_lons = np.array(
+        np.meshgrid(
+            np.asarray(reanalysis["latitude"]).flatten(),
+            np.asarray(reanalysis["longitude"]).flatten(),
+        )
+    ).T.reshape((-1, 2))
+
+    checkpoint_callback = ModelCheckpoint(dirpath=ckpt_path, save_top_k=1, monitor="loss")
+
+    dset = DataLoader(Era5Dataset(reanalysis, time_step=time_step), batch_size=10, num_workers=8)
+
+    single_step_model_state_dict = torch.load(ckpt_path / ckpt_name)
+
+    model = LitLoRAFengWuGHR(
+        lat_lons=lat_lons,
+        single_step_model_state_dict=single_step_model_state_dict,
+        time_step=time_step,
+        rank=rank,
+        ##########
+        channels=channels,
+        image_size=(721 // grid_step, 1440 // grid_step),
+        patch_size=patch_size,
+        depth=5,
+        heads=4,
+        mlp_dim=5,
+    )
+    trainer = pl.Trainer(
+        accelerator="gpu",
+        devices=-1,
+        max_epochs=100,
+        precision="16-mixed",
+        callbacks=[checkpoint_callback],
+        log_every_n_steps=3,
+        strategy="ddp_find_unused_parameters_true",
+    )
+
+    trainer.fit(model, dset)
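# ---------------------------------------------------------------------------
# Usage note (not part of the patch): a minimal sketch of how the new LoRA
# pieces are meant to compose with MetaModel, assuming the fixes above. The
# grid, hyperparameters, and tensor shapes are illustrative only and mirror
# tests/test_model.py rather than any production configuration.
import torch

from graph_weather.models import LoRAModule, MetaModel

# Coarse 5-degree grid; any list of (lat, lon) pairs works.
lat_lons = [(lat, lon) for lat in range(-90, 90, 5) for lon in range(0, 360, 5)]

# Single-step FengWu-GHR meta model (image_size must be divisible by patch_size).
model = MetaModel(
    lat_lons, image_size=(36, 72), patch_size=4, depth=1, heads=1, mlp_dim=7,
    channels=3, dim_head=64
)

# Wrap the single-step model: every nn.Linear inside it is swapped for a
# LoRALayer, so only the low-rank A/B matrices add trainable parameters for
# the longer lead-time steps trained in train/lora.py.
lora_model = LoRAModule(model, r=4)

features = torch.randn(1, len(lat_lons), 3)  # (batch, grid points, channels)
out = lora_model(features)                   # same shape as the input features
assert out.shape == features.shape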