diff --git a/graph_weather/data/gencast_dataloader.py b/graph_weather/data/gencast_dataloader.py index 1aa2f22a..460765f6 100644 --- a/graph_weather/data/gencast_dataloader.py +++ b/graph_weather/data/gencast_dataloader.py @@ -20,15 +20,6 @@ class GenCastDataset(Dataset): """ Dataset class for GenCast training data. - - Args: - obs_path: dataset path. - atmospheric_features: list of features depending on pressure levels. - single_features: list of features not depending on pressure levels. - static_features: list of features not depending on time. - max_year: max year to include in training set. Defaults to 2018. - time_step: time step between predictions. - E.g. 12h steps correspond to time_step = 2 in a 6h dataset. Defaults to 2. """ def __init__( @@ -42,17 +33,32 @@ def __init__( ): """ Initialize the GenCast dataset object. + + Args: + obs_path: dataset path. + atmospheric_features: list of features depending on pressure levels. + single_features: list of features not depending on pressure levels. + static_features: list of features not depending on time. + max_year: max year to include in training set. Defaults to 2018. + time_step: time step between predictions. + E.g. 12h steps correspond to time_step = 2 in a 6h dataset. Defaults to 2. 
""" super().__init__() self.data = xr.open_zarr(obs_path, chunks={}) self.max_year = max_year - self.num_lon = len(self.data["longitude"].values) - self.num_lat = len(self.data["latitude"].values) + self.grid_lon = self.data["longitude"].values + self.grid_lat = self.data["latitude"].values + self.num_lon = len(self.grid_lon) + self.num_lat = len(self.grid_lat) self.num_vars = len(self.data.keys()) self.pressure_levels = np.array(self.data["level"].values).astype( np.float32 ) # Need them for loss weighting + self.output_features_dim = len(atmospheric_features) * len(self.pressure_levels) + len( + single_features + ) + self.input_features_dim = self.output_features_dim + len(static_features) + 4 self.time_step = time_step # e.g. 12h steps correspond to time_step = 2 in a 6h dataset @@ -161,7 +167,7 @@ def __getitem__(self, item): # Concatenate timesteps inputs = np.concatenate([inputs[0, :, :, :], inputs[1, :, :, :]], axis=-1) - inputs = np.nan_to_num(inputs).astype(np.float32) + prev_inputs = np.nan_to_num(inputs).astype(np.float32) # Load target data ds_target_atm = ( @@ -186,16 +192,16 @@ def __getitem__(self, item): target_residuals = np.nan_to_num(target_residuals).astype(np.float32) # Corrupt targets with noise - noise_level = np.array([sample_noise_level()]).astype(np.float32) + noise_levels = np.array([sample_noise_level()]).astype(np.float32) noise = generate_isotropic_noise( num_lat=self.num_lat, num_samples=target_residuals.shape[-1] ) - corrupted_residuals = target_residuals + noise_level * noise + corrupted_targets = target_residuals + noise_levels * noise return ( - inputs, - noise_level, - corrupted_residuals, + corrupted_targets, + prev_inputs, + noise_levels, target_residuals, ) @@ -366,7 +372,7 @@ def __getitem__(self, item): inputs = np.concatenate([batched_inputs_norm, ds_clock], axis=-1) # Concatenate timesteps inputs = np.concatenate([inputs[:, 0, :, :, :], inputs[:, 1, :, :, :]], axis=-1) - inputs = np.nan_to_num(inputs).astype(np.float32) + 
prev_inputs = np.nan_to_num(inputs).astype(np.float32) # Compute targets residuals raw_targets = np.concatenate([ds_atm, ds_single], axis=-1) @@ -376,13 +382,13 @@ def __getitem__(self, item): # Corrupt targets with noise noise_levels = np.zeros((self.batch_size, 1), dtype=np.float32) - corrupted_residuals = np.zeros_like(target_residuals, dtype=np.float32) + corrupted_targets = np.zeros_like(target_residuals, dtype=np.float32) for b in range(self.batch_size): noise_level = sample_noise_level() noise = generate_isotropic_noise( num_lat=self.num_lat, num_samples=target_residuals.shape[-1] ) - corrupted_residuals[b] = target_residuals[b] + noise_level * noise + corrupted_targets[b] = target_residuals[b] + noise_level * noise noise_levels[b] = noise_level - return (inputs, noise_levels, corrupted_residuals, target_residuals) + return (corrupted_targets, prev_inputs, noise_levels, target_residuals) diff --git a/graph_weather/models/gencast/__init__.py b/graph_weather/models/gencast/__init__.py index 1edcbc86..b83b9290 100644 --- a/graph_weather/models/gencast/__init__.py +++ b/graph_weather/models/gencast/__init__.py @@ -1,4 +1,5 @@ """Main import for GenCast""" +from .denoiser import Denoiser from .graph.graph_builder import GraphBuilder from .weighted_mse_loss import WeightedMSELoss diff --git a/graph_weather/models/gencast/denoiser.py b/graph_weather/models/gencast/denoiser.py new file mode 100644 index 00000000..a70700b9 --- /dev/null +++ b/graph_weather/models/gencast/denoiser.py @@ -0,0 +1,216 @@ +"""Denoiser. + +The denoiser takes as inputs the previous two timesteps, the corrupted target residual, and the +noise level, and outputs the denoised predictions. It performs the following tasks: +1. Initializes the graph, encoder, processor, and decoder. +2. Computes f_theta as the combination of encoder, processor, and decoder. +3. Preconditions f_theta on the noise levels using the parametrization from Karras et al. (2022). 
+""" + +import einops +import numpy as np +import torch +from torch_geometric.data import Batch + +from graph_weather.models.gencast.graph.graph_builder import GraphBuilder +from graph_weather.models.gencast.layers.decoder import Decoder +from graph_weather.models.gencast.layers.encoder import Encoder +from graph_weather.models.gencast.utils.noise import Preconditioner + + +class Denoiser(torch.nn.Module): + """GenCast's Denoiser.""" + + def __init__( + self, + grid_lon: np.ndarray, + grid_lat: np.ndarray, + input_features_dim: int, + output_features_dim: int, + hidden_dims: list[int] = [512, 512], + num_blocks: int = 16, + num_heads: int = 4, + splits: int = 6, + num_hops: int = 6, + device: torch.device = torch.device("cpu"), + ): + """Initialize the Denoiser. + + Args: + grid_lon (np.ndarray): array of longitudes. + grid_lat (np.ndarray): array of latitudes. + input_features_dim (int): dimension of the input features for a single timestep. + output_features_dim (int): dimension of the target features. + hidden_dims (list[int], optional): list of dimensions for the hidden layers in the MLPs + used in GenCast. This also determines the latent dimension. Defaults to [512, 512]. + num_blocks (int, optional): number of transformer blocks in Processor. Defaults to 16. + num_heads (int, optional): number of heads for each transformer. Defaults to 4. + splits (int, optional): number of times to split the icosphere during graph building. + Defaults to 6. + num_hops (int, optional): the transformers will attend to the (2^num_hops)-neighbours + of each node. Defaults to 6. + device (torch.device, optional): device on which we want to build graph. + Defaults to torch.device("cpu").
+ """ + super().__init__() + self.num_lon = len(grid_lon) + self.num_lat = len(grid_lat) + self.input_features_dim = input_features_dim + self.output_features_dim = output_features_dim + + # Initialize graph + self.graphs = GraphBuilder( + grid_lon=grid_lon, + grid_lat=grid_lat, + splits=splits, + num_hops=num_hops, + device=device, + ) + + # Initialize Encoder + self.encoder = Encoder( + grid_dim=output_features_dim + 2 * input_features_dim + self.graphs.grid_nodes_dim, + mesh_dim=self.graphs.mesh_nodes_dim, + edge_dim=self.graphs.g2m_edges_dim, + hidden_dims=hidden_dims, + activation_layer=torch.nn.SiLU, + use_layer_norm=True, + ) + + # Initialize Decoder + self.decoder = Decoder( + edges_dim=self.graphs.m2g_edges_dim, + output_dim=output_features_dim, + hidden_dims=hidden_dims, + activation_layer=torch.nn.SiLU, + use_layer_norm=True, + ) + + # Initialize preconditioning functions + self.precs = Preconditioner(sigma_data=1.0) + + def _check_shapes(self, corrupted_targets, prev_inputs, noise_levels): + batch_size = prev_inputs.shape[0] + exp_inputs_shape = (batch_size, self.num_lon, self.num_lat, 2 * self.input_features_dim) + exp_targets_shape = (batch_size, self.num_lon, self.num_lat, self.output_features_dim) + exp_noise_shape = (batch_size, 1) + + if not all( + [ + corrupted_targets.shape == exp_targets_shape, + prev_inputs.shape == exp_inputs_shape, + noise_levels.shape == exp_noise_shape, + ] + ): + raise ValueError( + "The shapes of the input tensors don't match with the initialization parameters: " + f"expected {exp_inputs_shape} for prev_inputs, {exp_targets_shape} for targets and " + f"{exp_noise_shape} for noise_levels." + ) + + def _run_encoder(self, grid_features): + # build big graph with batch_size disconnected copies of the graph, with features [(b n) f]. + batch_size = grid_features.shape[0] + g2m_batched = Batch.from_data_list([self.graphs.g2m_graph] * batch_size) + + # load features. 
+ grid_features = einops.rearrange(grid_features, "b n f -> (b n) f") + input_grid_nodes = torch.cat([grid_features, g2m_batched["grid_nodes"].x], dim=-1).type( + torch.float32 + ) + input_mesh_nodes = g2m_batched["mesh_nodes"].x + input_edge_attr = g2m_batched["grid_nodes", "to", "mesh_nodes"].edge_attr + edge_index = g2m_batched["grid_nodes", "to", "mesh_nodes"].edge_index + + # run the encoder. + latent_grid_nodes, latent_mesh_nodes = self.encoder( + input_grid_nodes=input_grid_nodes, + input_mesh_nodes=input_mesh_nodes, + input_edge_attr=input_edge_attr, + edge_index=edge_index, + ) + + # restore nodes dimension: [b, n, f] + latent_grid_nodes = einops.rearrange(latent_grid_nodes, "(b n) f -> b n f", b=batch_size) + latent_mesh_nodes = einops.rearrange(latent_mesh_nodes, "(b n) f -> b n f", b=batch_size) + return latent_grid_nodes, latent_mesh_nodes + + def _run_decoder(self, latent_mesh_nodes, latent_grid_nodes): + # build big graph with batch_size disconnected copies of the graph, with features [(b n) f]. + batch_size = latent_mesh_nodes.shape[0] + m2g_batched = Batch.from_data_list([self.graphs.m2g_graph] * batch_size) + + # load features. + input_mesh_nodes = einops.rearrange(latent_mesh_nodes, "b n f -> (b n) f") + input_grid_nodes = einops.rearrange(latent_grid_nodes, "b n f -> (b n) f") + input_edge_attr = m2g_batched["mesh_nodes", "to", "grid_nodes"].edge_attr + edge_index = m2g_batched["mesh_nodes", "to", "grid_nodes"].edge_index + + # run the decoder. + output_grid_nodes = self.decoder( + input_mesh_nodes=input_mesh_nodes, + input_grid_nodes=input_grid_nodes, + input_edge_attr=input_edge_attr, + edge_index=edge_index, + ) + + # restore nodes dimension: [b, n, f] + output_grid_nodes = einops.rearrange(output_grid_nodes, "(b n) f -> b n f", b=batch_size) + return output_grid_nodes + + def _run_processor(self, latent_mesh_nodes, noise_levels): + # TODO: add processor. 
+ return latent_mesh_nodes + + def _f_theta(self, grid_features, noise_levels): + # run encoder, processor and decoder. + latent_grid_nodes, latent_mesh_nodes = self._run_encoder(grid_features) + latent_mesh_nodes = self._run_processor(latent_mesh_nodes, noise_levels) + output_grid_nodes = self._run_decoder(latent_mesh_nodes, latent_grid_nodes) + return output_grid_nodes + + def forward( + self, corrupted_targets: torch.Tensor, prev_inputs: torch.Tensor, noise_levels: torch.Tensor + ) -> torch.Tensor: + """Compute the denoiser output. + + The denoiser is a version of the (encoder, processor, decoder)-model (called f_theta), + preconditioned on the noise levels, as described below: + + D(Z, X, sigma) := c_skip(sigma)Z + c_out(sigma) * f_theta(c_in(sigma)Z, X, c_noise(sigma)), + + where Z is the corrupted target, X is the previous two timesteps concatenated and sigma is + the noise level used for Z's corruption. + + Args: + corrupted_targets (torch.Tensor): the target residuals corrupted by noise. + prev_inputs (torch.Tensor): the previous two timesteps concatenated across the features' + dimension. + noise_levels (torch.Tensor): the noise level used for corruption. + """ + # check shapes. + self._check_shapes(corrupted_targets, prev_inputs, noise_levels) + + # flatten lon/lat dimensions. + prev_inputs = einops.rearrange(prev_inputs, "b lon lat f -> b (lon lat) f") + corrupted_targets = einops.rearrange(corrupted_targets, "b lon lat f -> b (lon lat) f") + + # apply preconditioning functions to target and noise. + scaled_targets = self.precs.c_in(noise_levels)[:, :, None] * corrupted_targets + scaled_noise_levels = self.precs.c_noise(noise_levels) + + # concatenate inputs and targets across features dimension. + grid_features = torch.cat((scaled_targets, prev_inputs), dim=-1) + + # run the model. + preds = self._f_theta(grid_features, scaled_noise_levels) + + # add skip connection. 
+ out = ( + self.precs.c_skip(noise_levels)[:, :, None] * corrupted_targets + + self.precs.c_out(noise_levels)[:, :, None] * preds + ) + + # restore lon/lat dimensions. + out = einops.rearrange(out, "b (lon lat) f -> b lon lat f", lon=self.num_lon) + return out diff --git a/graph_weather/models/gencast/layers/__init__.py b/graph_weather/models/gencast/layers/__init__.py new file mode 100644 index 00000000..ac817186 --- /dev/null +++ b/graph_weather/models/gencast/layers/__init__.py @@ -0,0 +1,5 @@ +"""GenCast layers.""" + +from .decoder import Decoder +from .encoder import Encoder +from .modules import MLP, InteractionNetwork diff --git a/graph_weather/models/gencast/layers/decoder.py b/graph_weather/models/gencast/layers/decoder.py new file mode 100644 index 00000000..7c9238c4 --- /dev/null +++ b/graph_weather/models/gencast/layers/decoder.py @@ -0,0 +1,113 @@ +"""Decoder layer. + +The decoder: +- perform a single message-passing step on mesh2grid using a classical interaction network. +- add a residual connection to the grid nodes. +""" + +import torch + +from graph_weather.models.gencast.layers.modules import MLP, InteractionNetwork + + +class Decoder(torch.nn.Module): + """GenCast's decoder.""" + + def __init__( + self, + edges_dim: int, + output_dim: int, + hidden_dims: list[int], + activation_layer: torch.nn.Module = torch.nn.ReLU, + use_layer_norm: bool = True, + ): + """Initialize the Decoder. + + Args: + edges_dim (int): dimension of edges' features. + output_dim (int): dimension of final output. + hidden_dims (list[int]): hidden dimensions of internal MLPs. + activation_layer (torch.nn.Module, optional): activation function of internal MLPs. + Defaults to torch.nn.ReLU. + use_layer_norm (bool, optional): if true add a LayerNorm at the end of each MLP. + Defaults to True. + """ + super().__init__() + + # All the MLPs in GenCast have same hidden and output dims. Hence, the embedding latent + # dimension and the MLPs' output dimension are the same. 
Moreover, for simplicity, we will + # ask the hidden dims just once for each MLP in a module: we don't need to specify them + # individually as arguments, even if the MLPs could have different roles. + self.latent_dim = hidden_dims[-1] + + # Embedders + self.edges_mlp = MLP( + input_dim=edges_dim, + hidden_dims=hidden_dims, + activation_layer=activation_layer, + use_layer_norm=use_layer_norm, + bias=True, + activate_final=False, + ) + + # Message Passing + self.gnn = InteractionNetwork( + sender_dim=self.latent_dim, + receiver_dim=self.latent_dim, + edge_attr_dim=self.latent_dim, + hidden_dims=hidden_dims, + use_layer_norm=use_layer_norm, + activation_layer=activation_layer, + ) + + # Final grid nodes update + self.grid_mlp_final = MLP( + input_dim=self.latent_dim, + hidden_dims=hidden_dims[:-1] + [output_dim], + activation_layer=activation_layer, + use_layer_norm=use_layer_norm, + bias=True, + activate_final=False, + ) + + def forward( + self, + input_mesh_nodes: torch.Tensor, + input_grid_nodes: torch.Tensor, + input_edge_attr: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + """Forward pass. + + Args: + input_mesh_nodes (torch.Tensor): mesh nodes' features. + input_grid_nodes (torch.Tensor): grid nodes' features. + input_edge_attr (torch.Tensor): mesh2grid edges' features. + edge_index (torch.Tensor): edge index tensor. + + Returns: + torch.Tensor: output grid nodes. + """ + if not ( + input_grid_nodes.shape[-1] == self.latent_dim + and input_mesh_nodes.shape[-1] == self.latent_dim + ): + raise ValueError( + "The dimension of grid nodes and mesh nodes' features must be " + "equal to the last hidden dimension."
+ ) + + # Embedding + edges_emb = self.edges_mlp(input_edge_attr) + + # Message-passing + residual connection + latent_grid_nodes = input_grid_nodes + self.gnn( + x=(input_mesh_nodes, input_grid_nodes), + edge_index=edge_index, + edge_attr=edges_emb, + ) + + # Update grid nodes + latent_grid_nodes = self.grid_mlp_final(latent_grid_nodes) + + return latent_grid_nodes diff --git a/graph_weather/models/gencast/layers/encoder.py b/graph_weather/models/gencast/layers/encoder.py new file mode 100644 index 00000000..e062752e --- /dev/null +++ b/graph_weather/models/gencast/layers/encoder.py @@ -0,0 +1,128 @@ +"""Encoder layer. + +The encoder: +- embeds grid nodes, mesh nodes and g2m edges' features to the latent space. +- perform a single message-passing step using a classical interaction network. +- add a residual connection to the mesh and grid nodes. +""" + +import torch + +from graph_weather.models.gencast.layers.modules import MLP, InteractionNetwork + + +class Encoder(torch.nn.Module): + """GenCast's encoder.""" + + def __init__( + self, + grid_dim: int, + mesh_dim: int, + edge_dim: int, + hidden_dims: list[int], + activation_layer: torch.nn.Module = torch.nn.ReLU, + use_layer_norm: bool = True, + ): + """Initialize the Encoder. + + Args: + grid_dim (int): dimension of grid nodes' features. + mesh_dim (int): dimension of mesh nodes' features + edge_dim (int): dimension of g2m edges' features + hidden_dims (list[int]): hidden dimensions of internal MLPs. + activation_layer (torch.nn.Module, optional): activation function of internal MLPs. + Defaults to torch.nn.ReLU. + use_layer_norm (bool, optional): if true add a LayerNorm at the end of each MLP. + Defaults to True. + """ + super().__init__() + + # All the MLPs in GenCast have same hidden and output dims. Hence, the embedding latent + # dimension and the MLPs' output dimension are the same. 
Moreover, for simplicity, we will + # ask the hidden dims just once for each MLP in a module: we don't need to specify them + # individually as arguments, even if the MLPs could have different roles. + self.latent_dim = hidden_dims[-1] + + # Embedders + self.grid_mlp = MLP( + input_dim=grid_dim, + hidden_dims=hidden_dims, + activation_layer=activation_layer, + use_layer_norm=use_layer_norm, + bias=True, + activate_final=False, + ) + + self.mesh_mlp = MLP( + input_dim=mesh_dim, + hidden_dims=hidden_dims, + activation_layer=activation_layer, + use_layer_norm=use_layer_norm, + bias=True, + activate_final=False, + ) + + self.edges_mlp = MLP( + input_dim=edge_dim, + hidden_dims=hidden_dims, + activation_layer=activation_layer, + use_layer_norm=use_layer_norm, + bias=True, + activate_final=False, + ) + + # Message Passing + self.gnn = InteractionNetwork( + sender_dim=self.latent_dim, + receiver_dim=self.latent_dim, + edge_attr_dim=self.latent_dim, + hidden_dims=hidden_dims, + use_layer_norm=use_layer_norm, + activation_layer=activation_layer, + ) + + # Final grid nodes update + self.grid_mlp_final = MLP( + input_dim=self.latent_dim, + hidden_dims=hidden_dims, + activation_layer=activation_layer, + use_layer_norm=use_layer_norm, + bias=True, + activate_final=False, + ) + + def forward( + self, + input_grid_nodes: torch.Tensor, + input_mesh_nodes: torch.Tensor, + input_edge_attr: torch.Tensor, + edge_index: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass. + + Args: + input_grid_nodes (torch.Tensor): grid nodes' features. + input_mesh_nodes (torch.Tensor): mesh nodes' features. + input_edge_attr (torch.Tensor): grid2mesh edges' features. + edge_index (torch.Tensor): edge index tensor. + + Returns: + tuple[torch.Tensor, torch.Tensor]: output grid nodes, output mesh nodes. 
+ """ + + # Embedding + grid_emb = self.grid_mlp(input_grid_nodes) + mesh_emb = self.mesh_mlp(input_mesh_nodes) + edges_emb = self.edges_mlp(input_edge_attr) + + # Message-passing + residual connection + latent_mesh_nodes = mesh_emb + self.gnn( + x=(grid_emb, mesh_emb), + edge_index=edge_index, + edge_attr=edges_emb, + ) + + # Update grid nodes + residual connection + latent_grid_nodes = grid_emb + self.grid_mlp_final(grid_emb) + + return latent_grid_nodes, latent_mesh_nodes diff --git a/graph_weather/models/gencast/layers/modules.py b/graph_weather/models/gencast/layers/modules.py new file mode 100644 index 00000000..a3ebed97 --- /dev/null +++ b/graph_weather/models/gencast/layers/modules.py @@ -0,0 +1,151 @@ +"""Modules""" + +import torch +import torch.nn as nn +from torch_geometric.nn import MessagePassing + + +class MLP(nn.Module): + """Classic multi-layer perceptron (MLP) module.""" + + def __init__( + self, + input_dim: int, + hidden_dims: list[int], + activation_layer: nn.Module = nn.ReLU, + use_layer_norm: bool = False, + bias: bool = True, + activate_final: bool = False, + ): + """Initialize MLP module. + + Args: + input_dim (int): dimension of input. + hidden_dims (List[int]): list of hidden linear layers dimensions. + activation_layer (torch.nn.Module): activation + function to use. Defaults to torch.nn.ReLU. + use_layer_norm (bool, optional): if true apply LayerNorm to output. Defaults to False. + bias (bool, optional): if true use bias in linear layers. Defaults to True. + activate_final (bool, optional): whether to apply the activation function to the final + layer. Defaults to False.
+ """ + super().__init__() + + # Initialize linear layers + self.linears = nn.ModuleList() + self.linears.append(nn.Linear(input_dim, hidden_dims[0], bias=bias)) + for i in range(0, len(hidden_dims) - 1): + self.linears.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1], bias=bias)) + + # Initialize activation + self.activation = activation_layer() + + # Initialize layer normalization + self.norm_layer = None + if use_layer_norm: + self.norm_layer = nn.LayerNorm(hidden_dims[-1]) + + self.activate_final = activate_final + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply MLP to input.""" + + for linear in self.linears[:-1]: + x = linear(x) + x = self.activation(x) + + x = self.linears[-1](x) + + if self.activate_final: + x = self.activation(x) + + if self.norm_layer is not None: + x = self.norm_layer(x) + return x + + +class InteractionNetwork(MessagePassing): + """Single message-passing interaction network as described in GenCast. + + This network performs two steps: + 1) message-passing: e'_ij = MLP([e_ij,v_i,v_j]) + 2) aggregation: v'_j = MLP([v_j, sum_i {e'_ij}]) + The underlying graph is a directed graph. + + Note: + We don't need to update edges in GenCast, hence we skip it. + """ + + def __init__( + self, + sender_dim: int, + receiver_dim: int, + edge_attr_dim: int, + hidden_dims: list[int], + use_layer_norm: bool = False, + activation_layer: nn.Module = nn.ReLU, + ): + """Initialize the Interaction Network. + + Args: + sender_dim (int): dimension of sender nodes' features. + receiver_dim (int): dimension of receiver nodes' features. + edge_attr_dim (int): dimension of the edge features. + hidden_dims (list[int]): list of sizes of MLP's linear layers. + use_layer_norm (bool): if true add layer normalization to MLP's last layer. + Defaults to False. + activation_layer (torch.nn.Module): activation function. Defaults to nn.ReLU. 
+ """ + super().__init__(aggr="add", flow="source_to_target") + self.mlp_edges = MLP( + input_dim=sender_dim + receiver_dim + edge_attr_dim, + hidden_dims=hidden_dims, + activation_layer=activation_layer, + use_layer_norm=use_layer_norm, + bias=True, + activate_final=False, + ) + self.mlp_nodes = MLP( + input_dim=receiver_dim + hidden_dims[-1], + hidden_dims=hidden_dims, + activation_layer=activation_layer, + use_layer_norm=use_layer_norm, + bias=True, + activate_final=False, + ) + + def message(self, x_i, x_j, edge_attr): + """Message-passing step.""" + x = torch.cat((x_i, x_j, edge_attr), dim=-1) + x = self.mlp_edges(x) + return x + + def forward( + self, + x: tuple[torch.Tensor, torch.Tensor], + edge_index: torch.Tensor, + edge_attr: torch.Tensor, + ) -> torch.Tensor: + """ + Compute the output of the Interaction Network. + + This method processes the input node features and edge attributes through + the network to produce the output node features. + + Args: + x (tuple[torch.Tensor, torch.Tensor]): a tuple containing: + - sender nodes' features (torch.Tensor): features of the sender nodes. + - receiver nodes' features (torch.Tensor): features of the receiver nodes. + edge_index (torch.Tensor): tensor containing edge indices, defining the + connections between nodes. + edge_attr (torch.Tensor): tensor containing edge features, representing + the attributes of each edge. + + Returns: + torch.Tensor: the resulting node features after applying the Interaction Network. 
+ """ + aggr = self.propagate( + edge_index, x=x, edge_attr=edge_attr, size=(x[0].shape[0], x[1].shape[0]) + ) + out = self.mlp_nodes(torch.cat((x[1], aggr), dim=-1)) + return out diff --git a/graph_weather/models/gencast/utils/noise.py b/graph_weather/models/gencast/utils/noise.py index 89c4269d..e44f50b3 100644 --- a/graph_weather/models/gencast/utils/noise.py +++ b/graph_weather/models/gencast/utils/noise.py @@ -2,6 +2,7 @@ import numpy as np import pyshtools as pysh +import torch def generate_isotropic_noise(num_lat, num_samples=1): @@ -25,7 +26,7 @@ def generate_isotropic_noise(num_lat, num_samples=1): for i in range(num_samples): clm = pysh.SHCoeffs.from_random(power) grid[:, :, i] = clm.expand(grid="DH2", extend=False).to_array().transpose() - return grid + return grid.astype(np.float32) def sample_noise_level(sigma_min=0.02, sigma_max=88, rho=7): @@ -48,3 +49,35 @@ def sample_noise_level(sigma_min=0.02, sigma_max=88, rho=7): sigma_max ** (1 / rho) + u * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) ) ** rho return noise_level + + +class Preconditioner(torch.nn.Module): + """Collection of preconditioning functions. + + These functions are described in Karras (2022), table 1. + """ + + def __init__(self, sigma_data: float = 1): + """Initialize the preconditioning functions. + + Args: + sigma_data (float): Karras suggests 0.5, GenCast 1. Defaults to 1. 
+ """ + super().__init__() + self.sigma_data = sigma_data + + def c_skip(self, sigma): + """Scaling factor for skip connection: sigma_data^2 / (sigma^2 + sigma_data^2).""" + return self.sigma_data**2 / (sigma**2 + self.sigma_data**2) + + def c_out(self, sigma): + """Scaling factor for output.""" + return sigma * self.sigma_data / torch.sqrt(sigma**2 + self.sigma_data**2) + + def c_in(self, sigma): + """Scaling factor for input.""" + return 1 / torch.sqrt(sigma**2 + self.sigma_data**2) + + def c_noise(self, sigma): + """Scaling factor for noise level.""" + return 1 / 4 * torch.log(sigma) diff --git a/graph_weather/models/gencast/weighted_mse_loss.py b/graph_weather/models/gencast/weighted_mse_loss.py index bded0c53..baa2b21c 100644 --- a/graph_weather/models/gencast/weighted_mse_loss.py +++ b/graph_weather/models/gencast/weighted_mse_loss.py @@ -68,15 +68,18 @@ def _lambda_sigma(self, noise_level): return noise_weights # [batch, 1] def forward( - self, pred: torch.Tensor, target: torch.Tensor, noise_level: torch.Tensor + self, + pred: torch.Tensor, + noise_level: torch.Tensor, + target: torch.Tensor, ) -> torch.Tensor: """Compute the loss. Args: pred (torch.Tensor): prediction of the model [batch, lon, lat, var]. - target (torch.Tensor): target tensor [batch, lon, lat, var]. noise_level (torch.Tensor): noise levels fed to the model for the corresponding predictions [batch, 1]. + target (torch.Tensor): target tensor [batch, lon, lat, var]. Returns: torch.Tensor: weighted MSE loss. @@ -84,7 +87,7 @@ def forward( # check shapes if not (pred.shape == target.shape): raise ValueError( - "redictions and targets must have same shape. The actual shapes " + "Predictions and targets must have same shape. The actual shapes " f"are {pred.shape} and {target.shape}."
) if not (len(pred.shape) == 4): diff --git a/tests/test_model.py b/tests/test_model.py index bc19cc79..1302f9d2 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -18,7 +18,7 @@ generate_isotropic_noise, sample_noise_level, ) -from graph_weather.models.gencast import GraphBuilder, WeightedMSELoss +from graph_weather.models.gencast import GraphBuilder, WeightedMSELoss, Denoiser def test_encoder(): @@ -326,4 +326,36 @@ def test_gencast_loss(): preds = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) noise_levels = torch.rand((batch_size, 1)) targets = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) - assert loss.forward(preds, targets, noise_levels) is not None + assert loss.forward(preds, noise_levels, targets) is not None + + +def test_gencast_denoiser(): + grid_lat = np.arange(-90, 90, 1) + grid_lon = np.arange(0, 360, 1) + input_features_dim = 10 + output_features_dim = 5 + batch_size = 3 + + denoiser = Denoiser( + grid_lon=grid_lon, + grid_lat=grid_lat, + input_features_dim=input_features_dim, + output_features_dim=output_features_dim, + hidden_dims=[16, 32], + num_blocks=3, + num_heads=4, + splits=0, + num_hops=1, + device=torch.device("cpu"), + ).eval() + + corrupted_targets = torch.randn((batch_size, len(grid_lon), len(grid_lat), output_features_dim)) + prev_inputs = torch.randn((batch_size, len(grid_lon), len(grid_lat), 2 * input_features_dim)) + noise_levels = torch.rand((batch_size, 1)) + 0.02  # sigma > 0: c_noise takes log(sigma) + + with torch.no_grad(): + preds = denoiser( + corrupted_targets=corrupted_targets, prev_inputs=prev_inputs, noise_levels=noise_levels + ) + + assert not torch.isnan(preds).any() diff --git a/train/gencast_demo.ipynb b/train/gencast_demo.ipynb index 78c09052..d8fd6d20 100644 --- a/train/gencast_demo.ipynb +++ b/train/gencast_demo.ipynb @@ -114,7 +114,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -479,6 +479,143 @@ "ax.set_proj_type('ortho')
\n", "plt.tight_layout()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Denoiser" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set the dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", + "\n", + "from graph_weather.data.gencast_dataloader import GenCastDataset\n", + "\n", + "atmospheric_features = [\n", + " \"geopotential\",\n", + " \"specific_humidity\",\n", + " \"temperature\",\n", + " \"u_component_of_wind\",\n", + " \"v_component_of_wind\",\n", + " \"vertical_velocity\",\n", + "]\n", + "single_features = [\n", + " \"2m_temperature\",\n", + " \"10m_u_component_of_wind\",\n", + " \"10m_v_component_of_wind\",\n", + " \"mean_sea_level_pressure\",\n", + " # \"sea_surface_temperature\",\n", + " \"total_precipitation_12hr\",\n", + "]\n", + "static_features = [\n", + " \"geopotential_at_surface\",\n", + " \"land_sea_mask\",\n", + "]\n", + "\n", + "obs_path = \"gs://weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_conservative.zarr\"\n", + "# obs_path = 'gs://weatherbench2/datasets/era5/1959-2022-6h-1440x721.zarr'\n", + "# obs_path = 'gs://weatherbench2/datasets/era5/1959-2022-6h-512x256_equiangular_conservative.zarr'\n", + "\n", + "dataset = GenCastDataset(\n", + " obs_path=obs_path,\n", + " atmospheric_features=atmospheric_features,\n", + " single_features=single_features,\n", + " static_features=static_features,\n", + " max_year=2018,\n", + " time_step=2,\n", + ")\n", + "\n", + "batch_size = 2\n", + "dataloader = DataLoader(dataset, batch_size=batch_size)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Predict" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Iteration 0 with training loss 0.1087784692645073\n", + "Iteration 1 with training loss 
0.10737644881010056\n", + "Iteration 2 with training loss 0.1024656742811203\n", + "Iteration 3 with training loss 0.10202587395906448\n", + "Iteration 4 with training loss 0.09606097638607025\n", + "Iteration 5 with training loss 0.09462107717990875\n" + ] + } + ], + "source": [ + "import torch\n", + "\n", + "from graph_weather.models.gencast import Denoiser, WeightedMSELoss\n", + "\n", + "denoiser = Denoiser(\n", + " grid_lon=dataset.grid_lon,\n", + " grid_lat=dataset.grid_lat,\n", + " input_features_dim=dataset.input_features_dim,\n", + " output_features_dim=dataset.output_features_dim,\n", + " hidden_dims=[16, 16],\n", + " num_blocks=3,\n", + " num_heads=4,\n", + " splits=0,\n", + " num_hops=1,\n", + " device=torch.device(\"cpu\"),\n", + ")\n", + "\n", + "criterion = WeightedMSELoss(\n", + " grid_lat=torch.tensor(dataset.grid_lat),\n", + " pressure_levels=torch.tensor(dataset.pressure_levels),\n", + " num_atmospheric_features=len(dataset.atmospheric_features),\n", + " single_features_weights=torch.tensor([1.0, 0.1, 0.1, 0.1, 0.1]),\n", + ")\n", + "\n", + "optimizer = torch.optim.AdamW(denoiser.parameters(), lr=1e-3)\n", + "\n", + "for i, data in enumerate(dataloader):\n", + " corrupted_targets, prev_inputs, noise_levels, target_residuals = data\n", + " denoiser.zero_grad()\n", + " preds = denoiser(\n", + " corrupted_targets=corrupted_targets, prev_inputs=prev_inputs, noise_levels=noise_levels\n", + " )\n", + " loss = criterion(preds, noise_levels, target_residuals)\n", + " loss.backward()\n", + "\n", + " optimizer.step()\n", + " print(f\"Iteration {i} with training loss {float(loss)}.\")\n", + " if i==5:\n", + " break\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": {