diff --git a/.ruff.toml b/.ruff.toml index e675595d..7ccca57c 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -1,10 +1,10 @@ # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default. -select = ["E", "F", "D", "I"] -ignore = ["D200","D202","D210","D212","D415","D105",] +lint.select = ["E", "F", "D", "I"] +lint.ignore = ["D200","D202","D210","D212","D415","D105"] # Allow autofix for all enabled rules (when `--fix`) is provided. -fixable = ["A", "B", "C", "D", "E", "F", "I"] -unfixable = [] +lint.fixable = ["A", "B", "C", "D", "E", "F", "I"] +lint.unfixable = [] # Exclude a variety of commonly ignored directories. exclude = [ @@ -35,22 +35,22 @@ exclude = [ line-length = 100 # Allow unused variables when underscore-prefixed. -dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +lint.dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" # Assume Python 3.10. target-version = "py311" -fix = false +fix=false # Group violations by containing file. output-format = "github" -ignore-init-module-imports = true +lint.ignore-init-module-imports = true -[mccabe] +[lint.mccabe] # Unlike Flake8, default to a complexity level of 10. max-complexity = 10 -[pydocstyle] +[lint.pydocstyle] # Use Google-style docstrings. convention = "google" -[per-file-ignores] +[lint.per-file-ignores] "__init__.py" = ["F401", "E402"] diff --git a/graph_weather/data/const.py b/graph_weather/data/const.py index e1d57b3d..e064cb38 100644 --- a/graph_weather/data/const.py +++ b/graph_weather/data/const.py @@ -6,7 +6,8 @@ 2. GFS Forecast Fields 3. ERA5 Reanalysis Fields -where the variance is the variance in the 3 hour change for a variable averaged across all lat/lon and pressure levels +where the variance is the variance in the 3 hour change for a variable averaged across all lat/lon + and pressure levels and time for (~100 random temporal frames, more the better) Min/Max/Mean/Stddev for all those plus each type of observation in observation files diff --git a/graph_weather/data/dataloader.py b/graph_weather/data/dataloader.py index 19b7d7a8..a5238d4e 100644 --- a/graph_weather/data/dataloader.py +++ b/graph_weather/data/dataloader.py @@ -2,8 +2,10 @@ The dataloader has to do a few things for the model to work correctly -1. Load the land-0sea mask, orography dataset, regridded from 0.1 to the correct resolution -2. Calculate the top-of-atmosphere solar radiation for each location at fcurrent time and 10 other +1. Load the land-0sea mask, orography dataset, regridded from 0.1 to the +correct resolution +2. Calculate the top-of-atmosphere solar radiation for each location at +fcurrent time and 10 other times +- 12 hours 3. Add day-of-year, sin(lat), cos(lat), sin(lon), cos(lon) as well 3. Batch data as either in geometric batches, or more normally @@ -20,7 +22,26 @@ class AnalysisDataset(Dataset): + """ + Dataset class for analysis data. + + Args: + filepaths: List of file paths. + invariant_path: Path to the invariant file. + mean: Mean value. + std Standard deviation value. + coarsen : Coarsening factor. Defaults to 8. + + Methods: + __init__: Initialize the AnalysisDataset object. + __len__: Get the length of the dataset. + __getitem__: Get an item from the dataset. + """ + def __init__(self, filepaths, invariant_path, mean, std, coarsen: int = 8): + """ + Initialize the AnalysisDataset object. + """ super().__init__() self.filepaths = sorted(filepaths) self.invariant_path = invariant_path @@ -124,7 +145,8 @@ def __getitem__(self, item): ], axis=-1, ) - # Not want to predict non-physics variables -> Output only the data variables? Would be simpler, and just add in the new ones each time + # Not want to predict non-physics variables -> Output only the data variables? + # Would be simpler, and just add in the new ones each time output_data = np.stack( [ @@ -154,9 +176,13 @@ def __getitem__(self, item): obs_data = xr.open_zarr( "/home/jacob/Development/prepbufr.gdas.20160101.t00z.nr.48h.raw.zarr", consolidated=True ) -# TODO Embedding? These should stay consistent across all of the inputs, so can just load the values, not the strings? -# Should only take in the quality markers, observations, reported observation time relative to start point + +# TODO Embedding? These should stay consistent across all of the inputs, so can just load the values +# not the strings? +# Should only take in the quality markers, observations, reported observation time relative to start +# point # Observation errors, and background values, lat/lon/height/speed of observing thing + print(obs_data) print(obs_data.hdr_inst_typ.values) print(obs_data.hdr_irpt_typ.values) diff --git a/graph_weather/models/analysis.py b/graph_weather/models/analysis.py index d44abc11..fb707e3a 100644 --- a/graph_weather/models/analysis.py +++ b/graph_weather/models/analysis.py @@ -32,23 +32,23 @@ def __init__( Args: observation_lat_lons: Lat/lon points of the observations - output_lat_lons: List of latitude and longitudes for the output analysis - resolution: Resolution of the H3 grid, prefer even resolutions, as - odd ones have octogons and heptagons as well - observation_dim: Input feature size - analysis_dim: Output Analysis feature dim - node_dim: Node hidden dimension - edge_dim: Edge hidden dimension - num_blocks: Number of message passing blocks in the Processor - hidden_dim_processor_node: Hidden dimension of the node processors - hidden_dim_processor_edge: Hidden dimension of the edge processors - hidden_layers_processor_node: Number of hidden layers in the node processors - hidden_layers_processor_edge: Number of hidden layers in the edge processors - hidden_dim_decoder:Number of hidden dimensions in the decoder - hidden_layers_decoder: Number of layers in the decoder - norm_type: Type of norm for the MLPs - one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None - use_checkpointing: Whether to use gradient checkpointing or not + output_lat_lons: List of latitude and longitudes for the output analysis + resolution: Resolution of the H3 grid, prefer even resolutions, as + odd ones have octogons and heptagons as well + observation_dim: Input feature size + analysis_dim: Output Analysis feature dim + node_dim: Node hidden dimension + edge_dim: Edge hidden dimension + num_blocks: Number of message passing blocks in the Processor + hidden_dim_processor_node: Hidden dimension of the node processors + hidden_dim_processor_edge: Hidden dimension of the edge processors + hidden_layers_processor_node: Number of hidden layers in the node processors + hidden_layers_processor_edge: Number of hidden layers in the edge processors + hidden_dim_decoder:Number of hidden dimensions in the decoder + hidden_layers_decoder: Number of layers in the decoder + norm_type: Type of norm for the MLPs + one of 'LayerNorm', 'GraphNorm', 'InstanceNorm', 'BatchNorm', 'MessageNorm', or None + use_checkpointing: Whether to use gradient checkpointing or not """ super().__init__() diff --git a/graph_weather/models/forecast.py b/graph_weather/models/forecast.py index 267bc950..6ba6d523 100644 --- a/graph_weather/models/forecast.py +++ b/graph_weather/models/forecast.py @@ -39,7 +39,8 @@ def __init__( odd ones have octogons and heptagons as well feature_dim: Input feature size aux_dim: Number of non-NWP features (i.e. landsea mask, lat/lon, etc) - output_dim: Optional, output feature size, useful if want only subset of variables in output + output_dim: Optional, output feature size, useful if want only subset of variables in + output node_dim: Node hidden dimension edge_dim: Edge hidden dimension num_blocks: Number of message passing blocks in the Processor diff --git a/graph_weather/models/losses.py b/graph_weather/models/losses.py index c0d8035d..8ef9af19 100644 --- a/graph_weather/models/losses.py +++ b/graph_weather/models/losses.py @@ -25,6 +25,8 @@ def __init__( Args: feature_variance: Variance for each of the physical features lat_lons: List of lat/lon pairs, used to generate weighting + device: checks for device whether it supports gpu or not + normalize: option for normalize """ # TODO Rescale by nominal static air density at each pressure level super().__init__() diff --git a/train/deepspeed_graph.py b/train/deepspeed_graph.py index 23ede888..496db920 100644 --- a/train/deepspeed_graph.py +++ b/train/deepspeed_graph.py @@ -1,6 +1,9 @@ +"""Module for training the Graph Weather forecaster model using PyTorch Lightning.""" + import pytorch_lightning as pl import torch from pytorch_lightning import Trainer +from torch.utils.data import DataLoader, Dataset from graph_weather import GraphWeatherForecaster @@ -11,13 +14,42 @@ class LitModel(pl.LightningModule): + """ + LightningModule for the weather forecasting model. + + Args: + lat_lons: List of latitude and longitude coordinates. + feature_dim: Dimension of the input features. + aux_dim : Dimension of the auxiliary features. + + Methods: + __init__: Initialize the LitModel object. + """ + def __init__(self, lat_lons, feature_dim, aux_dim): + """ + Initialize the LitModel object. + + Args: + lat_lons: List of latitude and longitude coordinates. + feature_dim : Dimension of the input features. + aux_dim : Dimension of the auxiliary features. + """ super().__init__() self.model = GraphWeatherForecaster( lat_lons=lat_lons, feature_dim=feature_dim, aux_dim=aux_dim ) def training_step(self, batch): + """ + Performs a training step. + + Args: + batch: A batch of training data. + + Returns: + The computed loss. + """ x, y = batch x = x.half() y = y.half() @@ -27,18 +59,41 @@ def training_step(self, batch): return loss def configure_optimizers(self): + """ + Configures the optimizer used during training. + + Returns: + The optimizer. + """ return torch.optim.AdamW(self.parameters()) def forward(self, x): - return self.model(x) + """ + Forward pass. + Args: + x (torch.Tensor): Input data. -# Fake data -from torch.utils.data import DataLoader, Dataset + Returns: + torch.Tensor: Output of the model. + """ + return self.model(x) class FakeDataset(Dataset): + """ + Dataset class for generating fake data. + + Methods: + __init__: Initialize the FakeDataset object. + __len__: Return the length of the dataset. + __getitem__: Get an item from the dataset. + """ + def __init__(self): + """ + Initialize the FakeDataset object. + """ super(FakeDataset, self).__init__() def __len__(self): diff --git a/train/pl_graph_weather.py b/train/pl_graph_weather.py index d2588435..d82cf8cd 100644 --- a/train/pl_graph_weather.py +++ b/train/pl_graph_weather.py @@ -19,10 +19,25 @@ def worker_init_fn(worker_id): + """ + Initialize random seed for worker. + + Args: + worker_id (int): ID of the worker. + + Returns: + None + """ np.random.seed(np.random.get_state()[1][0] + worker_id) def get_mean_stds(): + """ + Calculate means and standard deviations for forecast variables. + + Returns: + means and standard deviations dict for forecast variables + """ names = [ "CLMR", "GRLE", @@ -151,6 +166,7 @@ def get_mean_stds(): def process_data(data): + """Process the input data.""" data.update( { key: np.expand_dims(np.asarray(value), axis=-1) @@ -232,7 +248,26 @@ def process_data(data): class GraphDataModule(pl.LightningDataModule): + """ + LightningDataModule for loading graph data. + + Attributes: + batch_size : Batch size for the dataloader. + dataset : Dataset containing the loaded data. + + Methods: + __init__: Initialize the GraphDataModule object by loading and processing data. + train_dataloader: Create training dataloader. + """ + def __init__(self, deg: str = "2.0", batch_size: int = 1): + """ + Initialize the GraphDataModule object by loading and processing data. + + Args: + deg : Resolution of the dataset. + batch_size : Batch size for the dataloader. + """ super().__init__() self.batch_size = batch_size self.dataset = datasets.load_dataset( @@ -258,10 +293,31 @@ def __init__(self, deg: str = "2.0", batch_size: int = 1): ) def train_dataloader(self): + """ + Create training dataloader. + + Returns: + torch.utils.data.DataLoader: Training dataloader. + """ return DataLoader(self.dataset, batch_size=self.batch_size, num_workers=2) class LitGraphForecaster(pl.LightningModule): + """ + LightningModule for graph-based weather forecasting. + + Attributes: + model (GraphWeatherForecaster): Graph weather forecaster model. + criterion (NormalizedMSELoss): Loss criterion for training. + lr : Learning rate for optimizer. + + Methods: + __init__: Initialize the LitGraphForecaster object. + forward: Forward pass of the model. + training_step: Training step. + configure_optimizers: Configure the optimizer for training. + """ + def __init__( self, lat_lons: list, @@ -271,6 +327,17 @@ def __init__( num_blocks: int = 3, lr: float = 3e-4, ): + """ + Initialize the LitGraphForecaster object with the required args. + + Args: + lat_lons : List of latitude and longitude values. + feature_dim : Dimensionality of the input features. + aux_dim : Dimensionality of auxiliary features. + hidden_dim : Dimensionality of hidden layers in the model. + num_blocks : Number of graph convolutional blocks in the model. + lr (float): Learning rate for optimizer. + """ super().__init__() self.model = GraphWeatherForecaster( lat_lons, @@ -288,9 +355,29 @@ def __init__( self.save_hyperparameters() def forward(self, x): + """ + Forward pass . + + Args: + x (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Output tensor. + """ return self.model(x) def training_step(self, batch, batch_idx): + """ + Training step. + + Args: + batch (dict): Batch of data containing input and output tensors. + batch_idx (int): Index of the current batch. + + Returns: + torch.Tensor: Loss tensor. + """ + x, y = batch["input"], batch["output"] if torch.isnan(x).any() or torch.isnan(y).any(): return None @@ -299,6 +386,12 @@ def training_step(self, batch, batch_idx): return loss def configure_optimizers(self): + """ + Configure the optimizer. + + Returns: + torch.optim.Optimizer: Optimizer instance. + """ return torch.optim.AdamW(self.parameters(), lr=self.lr) @@ -328,6 +421,18 @@ def configure_optimizers(self): type=click.INT, ) def run(num_blocks, hidden, batch, gpus): + """ + Trainig process. + + Args: + num_blocks : Number of blocks. + hidden: Hidden dimension. + batch : Batch size. + gpus : Number of GPUs. + + Returns: + None + """ hf_ds = datasets.load_dataset( "openclimatefix/gfs-surface-pressure-2deg", split="train", streaming=False ) diff --git a/train/run.py b/train/run.py index 91eb403b..fd7fd9ee 100644 --- a/train/run.py +++ b/train/run.py @@ -1,5 +1,7 @@ """Training script for training the weather forecasting model""" +import time + import datasets import numpy as np import pandas as pd @@ -19,10 +21,19 @@ def worker_init_fn(worker_id): + """ + Initialize the random seed for each worker. + + Args: + worker_id (int): The ID of the worker. + + Returns: + None + """ np.random.seed(np.random.get_state()[1][0] + worker_id) -def get_mean_stds(): +def get_mean_stds(): # noqa: D103 names = [ "CLMR", "GRLE", @@ -148,7 +159,29 @@ def get_mean_stds(): class XrDataset(IterableDataset): + """ + Dataset class for loading data from Hugging Face datasets. + + Attributes: + dataset : Dataset containing the loaded data. + means : Dictionary containing mean values. + stds : Dictionary containing standard deviation values. + landsea: Dataset containing land-sea mask data. + landsea_fixed : Tensor containing fixed land-sea mask data. + + Methods: + __init__: Initialize the XrDataset object by loading data from Hugging Face datasets. + __iter__: Iterate through the dataset. + """ + def __init__(self, resolution="2.0deg"): + """ + Initialize the XrDataset object by loading data from Hugging Face datasets. + + Args: + resolution : Resolution of the dataset. + + """ super().__init__() if "2deg" in resolution: LATITUDE = 91 @@ -333,7 +366,7 @@ def __iter__(self): seed=np.random.randint(low=-1000, high=10000), buffer_size=4 ) for data in iter(self.dataset): - # TODO Currently leaves out lat/lon/Sun irradience, and land/sea mask and topographic data + # TODO Currently leaves out lat/lon/Sun irradience, land/sea mask and topographic data data.update( { key: np.expand_dims(np.asarray(value), axis=-1) @@ -468,7 +501,7 @@ def __iter__(self): ).to(device) optimizer = optim.AdamW(model.parameters(), lr=0.001) print("Done Setup") -import time + for epoch in range(100): # loop over the dataset multiple times running_loss = 0.0 diff --git a/train/run_fulll.py b/train/run_fulll.py index 0acac1a3..a6a222f1 100644 --- a/train/run_fulll.py +++ b/train/run_fulll.py @@ -3,9 +3,7 @@ import json import os import sys - -BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.append(BASE_DIR) +import time import numpy as np import torch @@ -18,9 +16,28 @@ from graph_weather.data import const from graph_weather.models.losses import NormalizedMSELoss +BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +sys.path.append(BASE_DIR) + class XrDataset(Dataset): + """ + Dataset class for loading data from Hugging Face datasets. + + Attributes: + filepaths : List of file paths to the data. + data : Dataset containing the loaded data. + + Methods: + __init__: Initialize the XrDataset object by loading data from Hugging Face datasets. + __len__: Get the length of the dataset. + __getitem__: Get an item from the dataset by index. + """ + def __init__(self): + """ + Initialize the XrDataset object by loading data from Hugging Face datasets. + """ super().__init__() with open("hf_forecasts.json", "r") as f: files = json.load(f) @@ -110,7 +127,7 @@ def __getitem__(self, item): model = GraphWeatherForecaster(lat_lons, feature_dim=597, num_blocks=6).to(device) optimizer = optim.AdamW(model.parameters(), lr=0.001) print("Done Setup") -import time + for epoch in range(100): # loop over the dataset multiple times running_loss = 0.0