Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Encoder and Decoder #117

Merged
merged 7 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 27 additions & 21 deletions graph_weather/data/gencast_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,6 @@
class GenCastDataset(Dataset):
"""
Dataset class for GenCast training data.

Args:
obs_path: dataset path.
atmospheric_features: list of features depending on pressure levels.
single_features: list of features not depending on pressure levels.
static_features: list of features not depending on time.
max_year: max year to include in training set. Defaults to 2018.
time_step: time step between predictions.
E.g. 12h steps correspond to time_step = 2 in a 6h dataset. Defaults to 2.
"""

def __init__(
Expand All @@ -42,17 +33,32 @@ def __init__(
):
"""
Initialize the GenCast dataset object.

Args:
obs_path: dataset path.
atmospheric_features: list of features depending on pressure levels.
single_features: list of features not depending on pressure levels.
static_features: list of features not depending on time.
max_year: max year to include in training set. Defaults to 2018.
time_step: time step between predictions.
E.g. 12h steps correspond to time_step = 2 in a 6h dataset. Defaults to 2.
"""
super().__init__()
self.data = xr.open_zarr(obs_path, chunks={})
self.max_year = max_year

self.num_lon = len(self.data["longitude"].values)
self.num_lat = len(self.data["latitude"].values)
self.grid_lon = self.data["longitude"].values
self.grid_lat = self.data["latitude"].values
self.num_lon = len(self.grid_lon)
self.num_lat = len(self.grid_lat)
self.num_vars = len(self.data.keys())
self.pressure_levels = np.array(self.data["level"].values).astype(
np.float32
) # Need them for loss weighting
self.output_features_dim = len(atmospheric_features) * len(self.pressure_levels) + len(
single_features
)
self.input_features_dim = self.output_features_dim + len(static_features) + 4

self.time_step = time_step # e.g. 12h steps correspond to time_step = 2 in a 6h dataset

Expand Down Expand Up @@ -161,7 +167,7 @@ def __getitem__(self, item):

# Concatenate timesteps
inputs = np.concatenate([inputs[0, :, :, :], inputs[1, :, :, :]], axis=-1)
inputs = np.nan_to_num(inputs).astype(np.float32)
prev_inputs = np.nan_to_num(inputs).astype(np.float32)

# Load target data
ds_target_atm = (
Expand All @@ -186,16 +192,16 @@ def __getitem__(self, item):
target_residuals = np.nan_to_num(target_residuals).astype(np.float32)

# Corrupt targets with noise
noise_level = np.array([sample_noise_level()]).astype(np.float32)
noise_levels = np.array([sample_noise_level()]).astype(np.float32)
noise = generate_isotropic_noise(
num_lat=self.num_lat, num_samples=target_residuals.shape[-1]
)
corrupted_residuals = target_residuals + noise_level * noise
corrupted_targets = target_residuals + noise_levels * noise

return (
inputs,
noise_level,
corrupted_residuals,
corrupted_targets,
prev_inputs,
noise_levels,
target_residuals,
)

Expand Down Expand Up @@ -366,7 +372,7 @@ def __getitem__(self, item):
inputs = np.concatenate([batched_inputs_norm, ds_clock], axis=-1)
# Concatenate timesteps
inputs = np.concatenate([inputs[:, 0, :, :, :], inputs[:, 1, :, :, :]], axis=-1)
inputs = np.nan_to_num(inputs).astype(np.float32)
prev_inputs = np.nan_to_num(inputs).astype(np.float32)

# Compute targets residuals
raw_targets = np.concatenate([ds_atm, ds_single], axis=-1)
Expand All @@ -376,13 +382,13 @@ def __getitem__(self, item):

# Corrupt targets with noise
noise_levels = np.zeros((self.batch_size, 1), dtype=np.float32)
corrupted_residuals = np.zeros_like(target_residuals, dtype=np.float32)
corrupted_targets = np.zeros_like(target_residuals, dtype=np.float32)
for b in range(self.batch_size):
noise_level = sample_noise_level()
noise = generate_isotropic_noise(
num_lat=self.num_lat, num_samples=target_residuals.shape[-1]
)
corrupted_residuals[b] = target_residuals[b] + noise_level * noise
corrupted_targets[b] = target_residuals[b] + noise_level * noise
noise_levels[b] = noise_level

return (inputs, noise_levels, corrupted_residuals, target_residuals)
return (corrupted_targets, prev_inputs, noise_levels, target_residuals)
1 change: 1 addition & 0 deletions graph_weather/models/gencast/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Main import for GenCast"""

from .denoiser import Denoiser
from .graph.graph_builder import GraphBuilder
from .weighted_mse_loss import WeightedMSELoss
216 changes: 216 additions & 0 deletions graph_weather/models/gencast/denoiser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
"""Denoiser.

The denoiser takes as inputs the previous two timesteps, the corrupted target residual, and the
noise level, and outputs the denoised predictions. It performs the following tasks:
1. Initializes the graph, encoder, processor, and decoder.
2. Computes f_theta as the combination of encoder, processor, and decoder.
3. Preconditions f_theta on the noise levels using the parametrization from Karras et al. (2022).
"""

from collections.abc import Sequence

import einops
import numpy as np
import torch
from torch_geometric.data import Batch

from graph_weather.models.gencast.graph.graph_builder import GraphBuilder
from graph_weather.models.gencast.layers.decoder import Decoder
from graph_weather.models.gencast.layers.encoder import Encoder
from graph_weather.models.gencast.utils.noise import Preconditioner


class Denoiser(torch.nn.Module):
"""GenCast's Denoiser."""

def __init__(
self,
grid_lon: np.ndarray,
grid_lat: np.ndarray,
input_features_dim: int,
output_features_dim: int,
hidden_dims: list[int] = [512, 512],
num_blocks: int = 16,
num_heads: int = 4,
splits: int = 6,
num_hops: int = 6,
device: torch.device = torch.device("cpu"),
):
"""Initialize the Denoiser.

Args:
grid_lon (np.ndarray): array of longitudes.
grid_lat (np.ndarray): array of latitudes.
input_features_dim (int): dimension of the input features for a single timestep.
output_features_dim (int): dimension of the target features.
hidden_dims (list[int], optional): list of dimensions for the hidden layers in the MLPs
used in GenCast. This also determines the latent dimension. Defaults to [512, 512].
num_blocks (int, optional): number of transformer blocks in Processor. Defaults to 16.
num_heads (int, optional): number of heads for each transformer. Defaults to 4.
splits (int, optional): number of time to split the icosphere during graph building.
Defaults to 6.
num_hops (int, optional): the transformes will attention to the (2^num_hops)-neighbours
of each node. Defaults to 6.
device (torch.device, optional): device on which we want to build graph.
Defaults to torch.device("cpu").
"""
super().__init__()
self.num_lon = len(grid_lon)
self.num_lat = len(grid_lat)
self.input_features_dim = input_features_dim
self.output_features_dim = output_features_dim

# Initialize graph
self.graphs = GraphBuilder(
grid_lon=grid_lon,
grid_lat=grid_lat,
splits=splits,
num_hops=num_hops,
device=device,
)

# Initialize Encoder
self.encoder = Encoder(
grid_dim=output_features_dim + 2 * input_features_dim + self.graphs.grid_nodes_dim,
mesh_dim=self.graphs.mesh_nodes_dim,
edge_dim=self.graphs.g2m_edges_dim,
hidden_dims=hidden_dims,
activation_layer=torch.nn.SiLU,
use_layer_norm=True,
)

# Initialize Decoder
self.decoder = Decoder(
edges_dim=self.graphs.m2g_edges_dim,
output_dim=output_features_dim,
hidden_dims=hidden_dims,
activation_layer=torch.nn.SiLU,
use_layer_norm=True,
)

# Initialize preconditioning functions
self.precs = Preconditioner(sigma_data=1.0)

def _check_shapes(self, corrupted_targets, prev_inputs, noise_levels):
batch_size = prev_inputs.shape[0]
exp_inputs_shape = (batch_size, self.num_lon, self.num_lat, 2 * self.input_features_dim)
exp_targets_shape = (batch_size, self.num_lon, self.num_lat, self.output_features_dim)
exp_noise_shape = (batch_size, 1)

if not all(
[
corrupted_targets.shape == exp_targets_shape,
prev_inputs.shape == exp_inputs_shape,
noise_levels.shape == exp_noise_shape,
]
):
raise ValueError(
"The shapes of the input tensors don't match with the initialization parameters: "
f"expected {exp_inputs_shape} for prev_inputs, {exp_targets_shape} for targets and "
f"{exp_noise_shape} for noise_levels."
)

def _run_encoder(self, grid_features):
# build big graph with batch_size disconnected copies of the graph, with features [(b n) f].
batch_size = grid_features.shape[0]
g2m_batched = Batch.from_data_list([self.graphs.g2m_graph] * batch_size)

# load features.
grid_features = einops.rearrange(grid_features, "b n f -> (b n) f")
input_grid_nodes = torch.cat([grid_features, g2m_batched["grid_nodes"].x], dim=-1).type(
torch.float32
)
input_mesh_nodes = g2m_batched["mesh_nodes"].x
input_edge_attr = g2m_batched["grid_nodes", "to", "mesh_nodes"].edge_attr
edge_index = g2m_batched["grid_nodes", "to", "mesh_nodes"].edge_index

# run the encoder.
latent_grid_nodes, latent_mesh_nodes = self.encoder(
input_grid_nodes=input_grid_nodes,
input_mesh_nodes=input_mesh_nodes,
input_edge_attr=input_edge_attr,
edge_index=edge_index,
)

# restore nodes dimension: [b, n, f]
latent_grid_nodes = einops.rearrange(latent_grid_nodes, "(b n) f -> b n f", b=batch_size)
latent_mesh_nodes = einops.rearrange(latent_mesh_nodes, "(b n) f -> b n f", b=batch_size)
return latent_grid_nodes, latent_mesh_nodes

def _run_decoder(self, latent_mesh_nodes, latent_grid_nodes):
# build big graph with batch_size disconnected copies of the graph, with features [(b n) f].
batch_size = latent_mesh_nodes.shape[0]
m2g_batched = Batch.from_data_list([self.graphs.m2g_graph] * batch_size)

# load features.
input_mesh_nodes = einops.rearrange(latent_mesh_nodes, "b n f -> (b n) f")
input_grid_nodes = einops.rearrange(latent_grid_nodes, "b n f -> (b n) f")
input_edge_attr = m2g_batched["mesh_nodes", "to", "grid_nodes"].edge_attr
edge_index = m2g_batched["mesh_nodes", "to", "grid_nodes"].edge_index

# run the decoder.
output_grid_nodes = self.decoder(
input_mesh_nodes=input_mesh_nodes,
input_grid_nodes=input_grid_nodes,
input_edge_attr=input_edge_attr,
edge_index=edge_index,
)

# restore nodes dimension: [b, n, f]
output_grid_nodes = einops.rearrange(output_grid_nodes, "(b n) f -> b n f", b=batch_size)
return output_grid_nodes

    def _run_processor(self, latent_mesh_nodes, noise_levels):
        """Run the processor over the latent mesh nodes.

        The transformer processor is not implemented yet: this is an identity
        placeholder that returns ``latent_mesh_nodes`` unchanged, and
        ``noise_levels`` is currently ignored.
        """
        # TODO: add processor.
        return latent_mesh_nodes

def _f_theta(self, grid_features, noise_levels):
# run encoder, processor and decoder.
latent_grid_nodes, latent_mesh_nodes = self._run_encoder(grid_features)
latent_mesh_nodes = self._run_processor(latent_mesh_nodes, noise_levels)
output_grid_nodes = self._run_decoder(latent_mesh_nodes, latent_grid_nodes)
return output_grid_nodes

    def forward(
        self, corrupted_targets: torch.Tensor, prev_inputs: torch.Tensor, noise_levels: torch.Tensor
    ) -> torch.Tensor:
        """Compute the denoiser output.

        The denoiser is a version of the (encoder, processor, decoder)-model (called f_theta),
        preconditioned on the noise levels, as described below:

        D(Z, X, sigma) := c_skip(sigma)Z + c_out(sigma) * f_theta(c_in(sigma)Z, X, c_noise(sigma)),

        where Z is the corrupted target, X is the previous two timesteps concatenated and sigma is
        the noise level used for Z's corruption.

        Args:
            corrupted_targets (torch.Tensor): the target residuals corrupted by noise,
                shape (batch, num_lon, num_lat, output_features_dim).
            prev_inputs (torch.Tensor): the previous two timesteps concatenated across the features'
                dimension, shape (batch, num_lon, num_lat, 2 * input_features_dim).
            noise_levels (torch.Tensor): the noise level used for corruption, shape (batch, 1).

        Returns:
            torch.Tensor: the denoised predictions, same shape as corrupted_targets.
        """
        # check shapes (raises ValueError on mismatch).
        self._check_shapes(corrupted_targets, prev_inputs, noise_levels)

        # flatten lon/lat dimensions so nodes form a single axis for the graph modules.
        prev_inputs = einops.rearrange(prev_inputs, "b lon lat f -> b (lon lat) f")
        corrupted_targets = einops.rearrange(corrupted_targets, "b lon lat f -> b (lon lat) f")

        # apply preconditioning functions to target and noise (Karras et al., 2022).
        # [:, :, None] broadcasts the per-sample scalar over nodes and features.
        scaled_targets = self.precs.c_in(noise_levels)[:, :, None] * corrupted_targets
        scaled_noise_levels = self.precs.c_noise(noise_levels)

        # concatenate inputs and targets across features dimension.
        grid_features = torch.cat((scaled_targets, prev_inputs), dim=-1)

        # run the model.
        preds = self._f_theta(grid_features, scaled_noise_levels)

        # add skip connection: D = c_skip * Z + c_out * f_theta(...).
        out = (
            self.precs.c_skip(noise_levels)[:, :, None] * corrupted_targets
            + self.precs.c_out(noise_levels)[:, :, None] * preds
        )

        # restore lon/lat dimensions.
        out = einops.rearrange(out, "b (lon lat) f -> b lon lat f", lon=self.num_lon)
        return out
5 changes: 5 additions & 0 deletions graph_weather/models/gencast/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""GenCast layers."""

from .decoder import Decoder
from .encoder import Encoder
from .modules import MLP, InteractionNetwork
Loading