From 5bc3b5c727d526f15ba2838455b95210161d29f6 Mon Sep 17 00:00:00 2001 From: gbruno16 Date: Mon, 10 Jun 2024 12:28:56 +0200 Subject: [PATCH 1/7] Add mlp and interaction network --- .../models/gencast/layers/__init__.py | 1 + .../models/gencast/layers/decoder.py | 0 .../models/gencast/layers/denoiser.py | 0 .../models/gencast/layers/encoder.py | 58 +++++++ .../models/gencast/layers/modules.py | 151 ++++++++++++++++++ train/gencast_demo.ipynb | 70 ++++++++ 6 files changed, 280 insertions(+) create mode 100644 graph_weather/models/gencast/layers/__init__.py create mode 100644 graph_weather/models/gencast/layers/decoder.py create mode 100644 graph_weather/models/gencast/layers/denoiser.py create mode 100644 graph_weather/models/gencast/layers/encoder.py create mode 100644 graph_weather/models/gencast/layers/modules.py diff --git a/graph_weather/models/gencast/layers/__init__.py b/graph_weather/models/gencast/layers/__init__.py new file mode 100644 index 00000000..51679cee --- /dev/null +++ b/graph_weather/models/gencast/layers/__init__.py @@ -0,0 +1 @@ +"""GenCast layers.""" \ No newline at end of file diff --git a/graph_weather/models/gencast/layers/decoder.py b/graph_weather/models/gencast/layers/decoder.py new file mode 100644 index 00000000..e69de29b diff --git a/graph_weather/models/gencast/layers/denoiser.py b/graph_weather/models/gencast/layers/denoiser.py new file mode 100644 index 00000000..e69de29b diff --git a/graph_weather/models/gencast/layers/encoder.py b/graph_weather/models/gencast/layers/encoder.py new file mode 100644 index 00000000..5c42b3cf --- /dev/null +++ b/graph_weather/models/gencast/layers/encoder.py @@ -0,0 +1,58 @@ +"""Encoder layer. + +The encoder: +- embeds grid nodes, mesh nodes and g2m edges' features to the latent space. +- perform a single message-passing step using a classical interaction network. +- add a residual connection to the mesh nodes. +""" +import torch +from torch.nn.modules import + +class Encoder(torch.nn.Module): + def __init__(self, + grid_nodes_input_dim, + mesh_nodes_input_dim, + edge_attr_input_dim, + mlp_hidden_dim, + latent_dim, + mlp_norm_type="LayerNorm", + mlp_act_function="swish"): + super().__init__() + + # Embedders + self.grid_nodes_mlp = MLP([grid_nodes_input_dim, mlp_hidden_dim, latent_dim], + norm = mlp_norm_type, + act=mlp_act_function) + self.mesh_nodes_mlp = MLP([mesh_nodes_input_dim, mlp_hidden_dim, latent_dim], + norm = mlp_norm_type, + act=mlp_act_function) + self.edge_attr_mlp = MLP([edge_attr_input_dim, mlp_hidden_dim, latent_dim], + norm = mlp_norm_type, + act=mlp_act_function) + + # Message Passing + self.conv = InteractionNetwork(senders_input_dim=latent_dim, + receivers_input_dim=latent_dim, + edges_input_dim=latent_dim, + hidden_dim=mlp_hidden_dim, + output_dim=latent_dim, + mlp_norm_type=mlp_norm_type, + mlp_act_function=mlp_act_function, + ) + + self.grid_nodes_mlp_2 = MLP([latent_dim, mlp_hidden_dim, latent_dim], + norm=mlp_norm_type, + act=mlp_act_function) + + def forward(self, input_grid_nodes, input_mesh_nodes, input_edge_attr, edge_index): + # Embedding + grid_nodes_emb = self.grid_nodes_mlp(input_grid_nodes) + mesh_nodes_emb = self.mesh_nodes_mlp(input_mesh_nodes) + edge_attr_emb = self.edge_attr_mlp(input_edge_attr) + + latent_mesh_nodes = mesh_nodes_emb + self.conv(x=(grid_nodes_emb, mesh_nodes_emb), + edge_index=edge_index, + edge_attr=edge_attr_emb) + latent_grid_nodes = grid_nodes_emb + self.grid_nodes_mlp_2(grid_nodes_emb) + # TODO: Ask why we need to update eg2m, since we don't use them later + return latent_mesh_nodes, latent_grid_nodes diff --git a/graph_weather/models/gencast/layers/modules.py b/graph_weather/models/gencast/layers/modules.py new file mode 100644 index 00000000..4cb968ca --- /dev/null +++ b/graph_weather/models/gencast/layers/modules.py @@ -0,0 +1,151 @@ +"""Modules""" + +import torch +import torch.nn as nn +from torch_geometric.nn import MessagePassing + + +class MLP(nn.Module): + """Classic multi-layer perceptron (MLP) module.""" + + def __init__( + self, + input_dim: int, + hidden_dims: list[int], + activation_layer: nn.Module = nn.ReLU, + use_layer_norm: bool = False, + bias: bool = True, + activate_final: bool = False, + ): + """Initialize MLP module. + + Args: + input_dim (int): dimension of input. + hidden_dims (List[int]): list of hidden linear layers dimensions. + activation_layer (torch.nn.Module): activation + function to use. Defaults to torch.nn.ReLU. + use_layer_norm (bool, optional): if Ttrue apply LayerNorm to output. Defaults to False. + bias (bool, optional): if true use bias in linear layers. Defaults to True. + activate_final (bool, optional): whether to apply the activation function to the final + layer. Defaults to False. + """ + super().__init__() + + # Initialize linear layers + self.linears = nn.ModuleList() + self.linears.append(nn.Linear(input_dim, hidden_dims[0], bias=bias)) + for i in range(0, len(hidden_dims) - 1): + self.linears.append(nn.Linear(hidden_dims[i], hidden_dims[i + 1], bias=bias)) + + # Initialize activation + self.activation = activation_layer() + + # Initialize layer normalization + self.norm_layer = None + if use_layer_norm: + self.norm_layer = nn.LayerNorm(hidden_dims[-1]) + + self.activate_final = activate_final + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Apply MLP to input.""" + + for linear in self.linears[:-1]: + x = linear(x) + x = self.activation(x) + + x = self.linears[-1](x) + + if self.activate_final: + x = self.activation(x) + + if self.norm_layer is not None: + x = self.norm_layer(x) + return x + + +class InteractionNetwork(MessagePassing): + """Single message-passing interaction network as described in GenCast. + + This network performs two steps: + 1) message-passing: e'_ij = MLP([e_ij,v_i,v_j]) + 2) aggregation: v'_j = MLP([v_j, sum_i {e'_ij}]) + The underlying graph is a directed graph. + + Note: + We don't need to update edges in GenCast, hence we skip that. + """ + + def __init__( + self, + sender_dim: int, + receiver_dim: int, + edge_attr_dim: int, + hidden_dims: list[int], + use_layer_norm: bool = False, + activation_layer: nn.Module = nn.ReLU, + ): + """Initialize the Interaction Network. + + Args: + sender_dim (int): dimension of sender nodes' features. + receiver_dim (int): dimension of receiver nodes' features. + edge_attr_dim (int): dimension of the edge features. + hidden_dims (list[int]): list of sizes of MLP's linear layers. + use_layer_norm (bool): if true add layer normalization to MLP's last layer. + Defaults to False. + activation_layer (torch.nn.Module): activation function. Defaults to nn.ReLU. + """ + super().__init__(aggr="add", flow="source_to_target") + self.mlp_edges = MLP( + input_dim=sender_dim + receiver_dim + edge_attr_dim, + hidden_dims=hidden_dims, + activation_layer=activation_layer, + use_layer_norm=use_layer_norm, + bias=True, + activate_final=False, + ) + self.mlp_nodes = MLP( + input_dim=receiver_dim + hidden_dims[-1], + hidden_dims=hidden_dims, + activation_layer=activation_layer, + use_layer_norm=use_layer_norm, + bias=True, + activate_final=False, + ) + + def message(self, x_i, x_j, edge_attr): + """Message-passing step.""" + x = torch.cat((x_i, x_j, edge_attr), dim=-1) + x = self.mlp_edges(x) + return x + + def forward( + self, + x: tuple[torch.Tensor, torch.Tensor], + edge_index: torch.Tensor, + edge_attr: torch.Tensor, + ) -> torch.Tensor: + """ + Compute the output of the Interaction Network. + + This method processes the input node features and edge attributes through + the network to produce the output node features. + + Args: + x (tuple[torch.Tensor, torch.Tensor]): a tuple containing: + - sender nodes' features (torch.Tensor): features of the sender nodes. + - receiver nodes' features (torch.Tensor): features of the receiver nodes. + edge_index (torch.Tensor): tensor containing edge indices, defining the + connections between nodes. + edge_attr (torch.Tensor): tensor containing edge features, representing + the attributes of each edge. + + Returns: + torch.Tensor: the resulting node features after applying the Interaction Network. + """ + aggr = self.propagate( + edge_index, x=x, edge_attr=edge_attr, size=(x[0].shape[0], x[1].shape[0]) + ) + out = self.mlp_nodes(torch.cat((x[1], aggr), dim=-1)) + return out diff --git a/train/gencast_demo.ipynb b/train/gencast_demo.ipynb index 78c09052..a970d52b 100644 --- a/train/gencast_demo.ipynb +++ b/train/gencast_demo.ipynb @@ -479,6 +479,76 @@ "ax.set_proj_type('ortho') \n", "plt.tight_layout()" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Encoder and Decoder" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/gbruno16/Code/graph_weather_dev\n" + ] + } + ], + "source": [ + "%cd ../" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "\n", + "from graph_weather.models.gencast.layers.modules import MLP, InteractionNetwork\n", + "\n", + "nn = MLP(\n", + " input_dim=5, hidden_dims=[10, 5, 4, 2], activation_layer=torch.nn.GELU, use_layer_norm=True\n", + ")\n", + "\n", + "edge_index = torch.tensor(\n", + " [[0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4], [0, 1, 2, 0, 1, 2, 0, 1, 1, 2, 0]], dtype=torch.long\n", + ")\n", + "\n", + "sender_nodes = 5\n", + "receiver_nodes = 3\n", + "sender_dim = 10\n", + "receiver_dim = 11\n", + "edge_attr_dim = 12\n", + "hidden_dims = [13, 14]\n", + "\n", + "sender_features = torch.rand((sender_nodes, sender_dim))\n", + "receiver_features = torch.rand((receiver_nodes, receiver_dim))\n", + "edge_attr = torch.rand((len(edge_index[0]), edge_attr_dim))\n", + "gnn = InteractionNetwork(\n", + " sender_dim=sender_dim,\n", + " receiver_dim=receiver_dim,\n", + " edge_attr_dim=edge_attr_dim,\n", + " hidden_dims=hidden_dims,\n", + ")\n", + "\n", + "assert gnn((sender_features, receiver_features), edge_index, edge_attr).shape == (receiver_nodes,\n", + " hidden_dims[-1])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { From 06d03eefd6ef797c92a71939e3f4377480facdf0 Mon Sep 17 00:00:00 2001 From: gbruno16 Date: Wed, 12 Jun 2024 11:43:55 +0200 Subject: [PATCH 2/7] Add denoiser --- graph_weather/data/gencast_dataloader.py | 28 +- graph_weather/models/gencast/__init__.py | 1 + graph_weather/models/gencast/denoiser.py | 207 ++++++++++ .../models/gencast/layers/__init__.py | 6 +- .../models/gencast/layers/decoder.py | 113 ++++++ .../models/gencast/layers/denoiser.py | 0 .../models/gencast/layers/encoder.py | 158 ++++++-- .../models/gencast/layers/modules.py | 2 +- graph_weather/models/gencast/utils/noise.py | 33 ++ train/gencast_demo.ipynb | 381 +++++++++++++++++- 10 files changed, 860 insertions(+), 69 deletions(-) create mode 100644 graph_weather/models/gencast/denoiser.py delete mode 100644 graph_weather/models/gencast/layers/denoiser.py diff --git a/graph_weather/data/gencast_dataloader.py b/graph_weather/data/gencast_dataloader.py index 1aa2f22a..93e9e221 100644 --- a/graph_weather/data/gencast_dataloader.py +++ b/graph_weather/data/gencast_dataloader.py @@ -20,15 +20,6 @@ class GenCastDataset(Dataset): """ Dataset class for GenCast training data. - - Args: - obs_path: dataset path. - atmospheric_features: list of features depending on pressure levels. - single_features: list of features not depending on pressure levels. - static_features: list of features not depending on time. - max_year: max year to include in training set. Defaults to 2018. - time_step: time step between predictions. - E.g. 12h steps correspond to time_step = 2 in a 6h dataset. Defaults to 2. """ def __init__( @@ -42,17 +33,32 @@ def __init__( ): """ Initialize the GenCast dataset object. + + Args: + obs_path: dataset path. + atmospheric_features: list of features depending on pressure levels. + single_features: list of features not depending on pressure levels. + static_features: list of features not depending on time. + max_year: max year to include in training set. Defaults to 2018. + time_step: time step between predictions. + E.g. 12h steps correspond to time_step = 2 in a 6h dataset. Defaults to 2. """ super().__init__() self.data = xr.open_zarr(obs_path, chunks={}) self.max_year = max_year - self.num_lon = len(self.data["longitude"].values) - self.num_lat = len(self.data["latitude"].values) + self.grid_lon = self.data["longitude"].values + self.grid_lat = self.data["latitude"].values + self.num_lon = len(self.grid_lon) + self.num_lat = len(self.grid_lat) self.num_vars = len(self.data.keys()) self.pressure_levels = np.array(self.data["level"].values).astype( np.float32 ) # Need them for loss weighting + self.output_features_dim = len(atmospheric_features) * len(self.pressure_levels) + len( + single_features + ) + self.input_features_dim = self.output_features_dim + len(static_features) self.time_step = time_step # e.g. 12h steps correspond to time_step = 2 in a 6h dataset diff --git a/graph_weather/models/gencast/__init__.py b/graph_weather/models/gencast/__init__.py index 1edcbc86..b83b9290 100644 --- a/graph_weather/models/gencast/__init__.py +++ b/graph_weather/models/gencast/__init__.py @@ -1,4 +1,5 @@ """Main import for GenCast""" +from .denoiser import Denoiser from .graph.graph_builder import GraphBuilder from .weighted_mse_loss import WeightedMSELoss diff --git a/graph_weather/models/gencast/denoiser.py b/graph_weather/models/gencast/denoiser.py new file mode 100644 index 00000000..7038f118 --- /dev/null +++ b/graph_weather/models/gencast/denoiser.py @@ -0,0 +1,207 @@ +"""Denoiser.""" + +import einops +import numpy as np +import torch +from torch_geometric.data import Batch + +from graph_weather.models.gencast.graph.graph_builder import GraphBuilder +from graph_weather.models.gencast.layers.decoder import Decoder +from graph_weather.models.gencast.layers.encoder import Encoder +from graph_weather.models.gencast.utils.noise import Preconditioner + + +class Denoiser(torch.nn.Module): + """GenCast's Denoiser.""" + + def __init__( + self, + grid_lon: np.ndarray, + grid_lat: np.ndarray, + input_features_dim: int, + output_features_dim: int, + hidden_dims: list[int] = [512, 512], + num_blocks: int = 16, + num_heads: int = 4, + splits: int = 6, + num_hops: int = 6, + device: torch.device = torch.device("cpu"), + ): + """Initialize the Denoiser. + + Args: + grid_lon (np.ndarray): array of longitudes. + grid_lat (np.ndarray): array of latitudes. + input_features_dim (int): dimension of the input features for a single timestep. + output_features_dim (int): dimension of the target features. + hidden_dims (list[int], optional): list of dimensions for the hidden layers in the MLPs + used in GenCast. This also determines the latent dimension. Defaults to [512, 512]. + num_blocks (int, optional): number of transformer blocks in Processor. Defaults to 16. + num_heads (int, optional): number of heads for each transformer. Defaults to 4. + splits (int, optional): number of time to split the icosphere during graph building. + Defaults to 6. + num_hops (int, optional): the transformes will attention to the (2^num_hops)-neighbours + of each node. Defaults to 6. + device (torch.device, optional): device on which we want to build graph. + Defaults to torch.device("cpu"). + """ + super().__init__() + self.num_lon = len(grid_lon) + self.num_lat = len(grid_lat) + self.input_features_dim = input_features_dim + self.output_features_dim = output_features_dim + + # Initialize graph + self.graphs = GraphBuilder( + grid_lon=grid_lon, + grid_lat=grid_lat, + splits=splits, + num_hops=num_hops, + device=device, + ) + + # Initialize Encoder + self.encoder = Encoder( + grid_dim=self.graphs.grid_nodes_dim + 2 * input_features_dim, + mesh_dim=self.graphs.mesh_nodes_dim, + edge_dim=self.graphs.g2m_edges_dim, + hidden_dims=hidden_dims, + activation_layer=torch.nn.SiLU, + use_layer_norm=True, + ) + + # Initialize Decoder + self.decoder = Decoder( + edges_dim=self.graphs.m2g_edges_dim, + output_dim=output_features_dim, + hidden_dims=hidden_dims, + activation_layer=torch.nn.SiLU, + use_layer_norm=True, + ) + + # Initialize preconditioning functions + self.precs = Preconditioner(sigma_data=1.0) + + def _check_shapes(self, corrupted_targets, prev_inputs, noise_levels): + batch_size = prev_inputs.shape[0] + exp_inputs_shape = (batch_size, self.num_lon, self.num_lat, 2 * self.input_features_dim) + exp_targets_shape = (batch_size, self.num_lon, self.num_lat, self.output_features_dim) + exp_noise_shape = (batch_size, 1) + + if not all( + corrupted_targets.shape == exp_targets_shape, + prev_inputs.shape == exp_inputs_shape, + noise_levels.shape == exp_noise_shape, + ): + raise ValueError( + "The shapes of the input tensors don't match with the initialization parameters: " + f"expected {exp_inputs_shape} for prev_inputs, {exp_targets_shape} for targets and " + f"{exp_noise_shape} for noise_levels." + ) + + def _run_encoder(self, grid_features): + # build big graph with batch_size disconnected copies of the graph, with features [(b n) f]. + batch_size = grid_features.shape[0] + g2m_batched = Batch.from_data_list([self.graphs.g2m_graph] * batch_size) + + # load features. + grid_features = einops.rearrange(grid_features, "b n f -> (b n) f") + input_grid_nodes = torch.cat([grid_features, g2m_batched["grid_nodes"].x], dim=-1).type( + torch.float32 + ) + input_mesh_nodes = g2m_batched["mesh_nodes"].x + input_edge_attr = g2m_batched["grid_nodes", "to", "mesh_nodes"].edge_attr + edge_index = g2m_batched["grid_nodes", "to", "mesh_nodes"].edge_index + + # run the encoder. + latent_grid_nodes, latent_mesh_nodes = self.encoder( + input_grid_nodes=input_grid_nodes, + input_mesh_nodes=input_mesh_nodes, + input_edge_attr=input_edge_attr, + edge_index=edge_index, + ) + + # restore nodes dimension: [b, n, f] + latent_grid_nodes = einops.rearrange(latent_grid_nodes, "(b n) f -> b n f", b=batch_size) + latent_mesh_nodes = einops.rearrange(latent_mesh_nodes, "(b n) f -> b n f", b=batch_size) + return latent_grid_nodes, latent_mesh_nodes + + def _run_decoder(self, latent_mesh_nodes, latent_grid_nodes): + # build big graph with batch_size disconnected copies of the graph, with features [(b n) f]. + batch_size = latent_mesh_nodes.shape[0] + m2g_batched = Batch.from_data_list([self.graphs.m2g_graph] * batch_size) + + # load features. + input_mesh_nodes = einops.rearrange(latent_mesh_nodes, "b n f -> (b n) f") + input_grid_nodes = einops.rearrange(latent_grid_nodes, "b n f -> (b n) f") + input_edge_attr = m2g_batched["mesh_nodes", "to", "grid_nodes"].edge_attr + edge_index = m2g_batched["mesh_nodes", "to", "grid_nodes"].edge_index + + # run the decoder. + output_grid_nodes = self.decoder( + input_mesh_nodes=input_mesh_nodes, + input_grid_nodes=input_grid_nodes, + input_edge_attr=input_edge_attr, + edge_index=edge_index, + ) + + # restore nodes dimension: [b, n, f] + output_grid_nodes = einops.rearrange(output_grid_nodes, "(b n) f -> b n f", b=batch_size) + return output_grid_nodes + + def _run_processor(self, latent_mesh_nodes, noise_levels): + # TODO: add processor. + return latent_mesh_nodes + + def _f_theta(self, grid_features, noise_levels): + # run encoder, processor and decoder. + latent_grid_nodes, latent_mesh_nodes = self._run_encoder(grid_features) + latent_mesh_nodes = self._run_processor(latent_mesh_nodes, noise_levels) + output_grid_nodes = self._run_decoder(latent_mesh_nodes, latent_grid_nodes) + return output_grid_nodes + + def forward( + self, corrupted_targets: torch.Tensor, prev_inputs: torch.Tensor, noise_levels: torch.Tensor + ) -> torch.Tensor: + """Compute the denoiser output. + + The denoiser is a version of the (encoder, processor, decoder)-model (called f_theta), + preconditioned on the noise levels, as described below: + + D(Z, X, sigma) := c_skip(sigma)Z + c_out(sigma) * f_theta(c_in(sigma)Z, X, c_noise(sigma)), + + where Z is the corrupted target, X is the previous two timesteps concatenated and sigma is + the noise level used for Z's corruption. + + Args: + corrupted_targets (torch.Tensor): the target residuals corrupted by noise. + prev_inputs (torch.Tensor): the previous two timesteps concatenated across the features' + dimension. + noise_levels (torch.Tensor): the noise level used for corruption. + """ + # check shapes. + self._check_shapes(corrupted_targets, prev_inputs, noise_levels) + + # flatten lon/lat dimensions. + prev_inputs = einops.rearrange(prev_inputs, "b lon lat f -> b (lon lat) f") + corrupted_targets = einops.rearrange(corrupted_targets, "b lon lat f -> b (lon lat) f") + + # apply preconditioning functions to target and noise. + scaled_targets = self.precs.c_in(noise_levels)[:, :, None] * corrupted_targets + scaled_noise_levels = self.precs.c_noise(noise_levels) + + # concatenate inputs and targets across features dimension. + grid_features = torch.cat((scaled_targets, prev_inputs), dim=-1) + + # run the model. + preds = self._f_theta(grid_features, scaled_noise_levels) + + # add skip connection. + out = ( + self.precs.c_skip(noise_levels)[:, :, None] * corrupted_targets + + self.precs.c_out(noise_levels)[:, :, None] * preds + ) + + # restore lon/lat dimensions. + out = einops.rearrange(out, "b (lon lat) f -> b lon lat f", lon=self.num_lon) + return out diff --git a/graph_weather/models/gencast/layers/__init__.py b/graph_weather/models/gencast/layers/__init__.py index 51679cee..ac817186 100644 --- a/graph_weather/models/gencast/layers/__init__.py +++ b/graph_weather/models/gencast/layers/__init__.py @@ -1 +1,5 @@ -"""GenCast layers.""" \ No newline at end of file +"""GenCast layers.""" + +from .decoder import Decoder +from .encoder import Encoder +from .modules import MLP, InteractionNetwork diff --git a/graph_weather/models/gencast/layers/decoder.py b/graph_weather/models/gencast/layers/decoder.py index e69de29b..7c9238c4 100644 --- a/graph_weather/models/gencast/layers/decoder.py +++ b/graph_weather/models/gencast/layers/decoder.py @@ -0,0 +1,113 @@ +"""Decoder layer. + +The decoder: +- perform a single message-passing step on mesh2grid using a classical interaction network. +- add a residual connection to the grid nodes. +""" + +import torch + +from graph_weather.models.gencast.layers.modules import MLP, InteractionNetwork + + +class Decoder(torch.nn.Module): + """GenCast's decoder.""" + + def __init__( + self, + edges_dim: int, + output_dim: int, + hidden_dims: list[int], + activation_layer: torch.nn.Module = torch.nn.ReLU, + use_layer_norm: bool = True, + ): + """Initialize the Decoder. + + Args: + edges_dim (int): dimension of edges' features. + output_dim (int): dimension of final output. + hidden_dims (list[int]): hidden dimensions of internal MLPs. + activation_layer (torch.nn.Module, optional): activation function of internal MLPs. + Defaults to torch.nn.ReLU. + use_layer_norm (bool, optional): if true add a LayerNorm at the end of each MLP. + Defaults to True. + """ + super().__init__() + + # All the MLPs in GenCast have same hidden and output dims. Hence, the embedding latent + # dimension and the MLPs' output dimension are the same. Moreover, for simplicity, we will + # ask the hidden dims just once for each MLP in a module: we don't need to specify them + # individually as arguments, even if the MLPs could have different roles. + self.latent_dim = hidden_dims[-1] + + # Embedders + self.edges_mlp = MLP( + input_dim=edges_dim, + hidden_dims=hidden_dims, + activation_layer=activation_layer, + use_layer_norm=use_layer_norm, + bias=True, + activate_final=False, + ) + + # Message Passing + self.gnn = InteractionNetwork( + sender_dim=self.latent_dim, + receiver_dim=self.latent_dim, + edge_attr_dim=self.latent_dim, + hidden_dims=hidden_dims, + use_layer_norm=use_layer_norm, + activation_layer=activation_layer, + ) + + # Final grid nodes update + self.grid_mlp_final = MLP( + input_dim=self.latent_dim, + hidden_dims=hidden_dims[:-1] + [output_dim], + activation_layer=activation_layer, + use_layer_norm=use_layer_norm, + bias=True, + activate_final=False, + ) + + def forward( + self, + input_mesh_nodes: torch.Tensor, + input_grid_nodes: torch.Tensor, + input_edge_attr: torch.Tensor, + edge_index: torch.Tensor, + ) -> torch.Tensor: + """Forward pass. + + Args: + input_mesh_nodes (torch.Tensor): mesh nodes' features. + input_grid_nodes (torch.Tensor): grid nodes' features. + input_edge_attr (torch.Tensor): grid2mesh edges' features. + edge_index (torch.Tensor): edge index tensor. + + Returns: + torch.Tensor: output grid nodes. + """ + if not ( + input_grid_nodes.shape[-1] == self.latent_dim + and input_mesh_nodes.shape[-1] == self.latent_dim + ): + raise ValueError( + "The dimension of grid nodes and mesh nodes' features must be " + "equal to the last hidden dimension." + ) + + # Embedding + edges_emb = self.edges_mlp(input_edge_attr) + + # Message-passing + residual connection + latent_grid_nodes = input_grid_nodes + self.gnn( + x=(input_mesh_nodes, input_grid_nodes), + edge_index=edge_index, + edge_attr=edges_emb, + ) + + # Update grid nodes + latent_grid_nodes = self.grid_mlp_final(latent_grid_nodes) + + return latent_grid_nodes diff --git a/graph_weather/models/gencast/layers/denoiser.py b/graph_weather/models/gencast/layers/denoiser.py deleted file mode 100644 index e69de29b..00000000 diff --git a/graph_weather/models/gencast/layers/encoder.py b/graph_weather/models/gencast/layers/encoder.py index 5c42b3cf..e062752e 100644 --- a/graph_weather/models/gencast/layers/encoder.py +++ b/graph_weather/models/gencast/layers/encoder.py @@ -3,56 +3,126 @@ The encoder: - embeds grid nodes, mesh nodes and g2m edges' features to the latent space. - perform a single message-passing step using a classical interaction network. -- add a residual connection to the mesh nodes. +- add a residual connection to the mesh and grid nodes. """ + import torch -from torch.nn.modules import + +from graph_weather.models.gencast.layers.modules import MLP, InteractionNetwork + class Encoder(torch.nn.Module): - def __init__(self, - grid_nodes_input_dim, - mesh_nodes_input_dim, - edge_attr_input_dim, - mlp_hidden_dim, - latent_dim, - mlp_norm_type="LayerNorm", - mlp_act_function="swish"): + """GenCast's encoder.""" + + def __init__( + self, + grid_dim: int, + mesh_dim: int, + edge_dim: int, + hidden_dims: list[int], + activation_layer: torch.nn.Module = torch.nn.ReLU, + use_layer_norm: bool = True, + ): + """Initialize the Encoder. + + Args: + grid_dim (int): dimension of grid nodes' features. + mesh_dim (int): dimension of mesh nodes' features + edge_dim (int): dimension of g2m edges' features + hidden_dims (list[int]): hidden dimensions of internal MLPs. + activation_layer (torch.nn.Module, optional): activation function of internal MLPs. + Defaults to torch.nn.ReLU. + use_layer_norm (bool, optional): if true add a LayerNorm at the end of each MLP. + Defaults to True. + """ super().__init__() + # All the MLPs in GenCast have same hidden and output dims. Hence, the embedding latent + # dimension and the MLPs' output dimension are the same. Moreover, for simplicity, we will + # ask the hidden dims just once for each MLP in a module: we don't need to specify them + # individually as arguments, even if the MLPs could have different roles. + self.latent_dim = hidden_dims[-1] + # Embedders - self.grid_nodes_mlp = MLP([grid_nodes_input_dim, mlp_hidden_dim, latent_dim], - norm = mlp_norm_type, - act=mlp_act_function) - self.mesh_nodes_mlp = MLP([mesh_nodes_input_dim, mlp_hidden_dim, latent_dim], - norm = mlp_norm_type, - act=mlp_act_function) - self.edge_attr_mlp = MLP([edge_attr_input_dim, mlp_hidden_dim, latent_dim], - norm = mlp_norm_type, - act=mlp_act_function) - + self.grid_mlp = MLP( + input_dim=grid_dim, + hidden_dims=hidden_dims, + activation_layer=activation_layer, + use_layer_norm=use_layer_norm, + bias=True, + activate_final=False, + ) + + self.mesh_mlp = MLP( + input_dim=mesh_dim, + hidden_dims=hidden_dims, + activation_layer=activation_layer, + use_layer_norm=use_layer_norm, + bias=True, + activate_final=False, + ) + + self.edges_mlp = MLP( + input_dim=edge_dim, + hidden_dims=hidden_dims, + activation_layer=activation_layer, + use_layer_norm=use_layer_norm, + bias=True, + activate_final=False, + ) + # Message Passing - self.conv = InteractionNetwork(senders_input_dim=latent_dim, - receivers_input_dim=latent_dim, - edges_input_dim=latent_dim, - hidden_dim=mlp_hidden_dim, - output_dim=latent_dim, - mlp_norm_type=mlp_norm_type, - mlp_act_function=mlp_act_function, - ) - - self.grid_nodes_mlp_2 = MLP([latent_dim, mlp_hidden_dim, latent_dim], - norm=mlp_norm_type, - act=mlp_act_function) - - def forward(self, input_grid_nodes, input_mesh_nodes, input_edge_attr, edge_index): + self.gnn = InteractionNetwork( + sender_dim=self.latent_dim, + receiver_dim=self.latent_dim, + edge_attr_dim=self.latent_dim, + hidden_dims=hidden_dims, + use_layer_norm=use_layer_norm, + activation_layer=activation_layer, + ) + + # Final grid nodes update + self.grid_mlp_final = MLP( + input_dim=self.latent_dim, + hidden_dims=hidden_dims, + activation_layer=activation_layer, + use_layer_norm=use_layer_norm, + bias=True, + activate_final=False, + ) + + def forward( + self, + input_grid_nodes: torch.Tensor, + input_mesh_nodes: torch.Tensor, + input_edge_attr: torch.Tensor, + edge_index: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Forward pass. + + Args: + input_grid_nodes (torch.Tensor): grid nodes' features. + input_mesh_nodes (torch.Tensor): mesh nodes' features. + input_edge_attr (torch.Tensor): grid2mesh edges' features. + edge_index (torch.Tensor): edge index tensor. + + Returns: + tuple[torch.Tensor, torch.Tensor]: output grid nodes, output mesh nodes. + """ + # Embedding - grid_nodes_emb = self.grid_nodes_mlp(input_grid_nodes) - mesh_nodes_emb = self.mesh_nodes_mlp(input_mesh_nodes) - edge_attr_emb = self.edge_attr_mlp(input_edge_attr) - - latent_mesh_nodes = mesh_nodes_emb + self.conv(x=(grid_nodes_emb, mesh_nodes_emb), - edge_index=edge_index, - edge_attr=edge_attr_emb) - latent_grid_nodes = grid_nodes_emb + self.grid_nodes_mlp_2(grid_nodes_emb) - # TODO: Ask why we need to update eg2m, since we don't use them later - return latent_mesh_nodes, latent_grid_nodes + grid_emb = self.grid_mlp(input_grid_nodes) + mesh_emb = self.mesh_mlp(input_mesh_nodes) + edges_emb = self.edges_mlp(input_edge_attr) + + # Message-passing + residual connection + latent_mesh_nodes = mesh_emb + self.gnn( + x=(grid_emb, mesh_emb), + edge_index=edge_index, + edge_attr=edges_emb, + ) + + # Update grid nodes + residual connection + latent_grid_nodes = grid_emb + self.grid_mlp_final(grid_emb) + + return latent_grid_nodes, latent_mesh_nodes diff --git a/graph_weather/models/gencast/layers/modules.py b/graph_weather/models/gencast/layers/modules.py index 4cb968ca..a3ebed97 100644 --- a/graph_weather/models/gencast/layers/modules.py +++ b/graph_weather/models/gencast/layers/modules.py @@ -73,7 +73,7 @@ class InteractionNetwork(MessagePassing): The underlying graph is a directed graph. Note: - We don't need to update edges in GenCast, hence we skip that. + We don't need to update edges in GenCast, hence we skip it. """ def __init__( diff --git a/graph_weather/models/gencast/utils/noise.py b/graph_weather/models/gencast/utils/noise.py index 89c4269d..99964fef 100644 --- a/graph_weather/models/gencast/utils/noise.py +++ b/graph_weather/models/gencast/utils/noise.py @@ -2,6 +2,7 @@ import numpy as np import pyshtools as pysh +import torch def generate_isotropic_noise(num_lat, num_samples=1): @@ -48,3 +49,35 @@ def sample_noise_level(sigma_min=0.02, sigma_max=88, rho=7): sigma_max ** (1 / rho) + u * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho)) ) ** rho return noise_level + + +class Preconditioner(torch.nn.Module): + """Collection of preconditioning functions. + + These functions are described in Karras (2022), table 1. + """ + + def __init__(self, sigma_data: float = 1): + """Initialize the preconditioning functions. + + Args: + sigma_data (float): Karras suggests 0.5, GenCast 1. Defaults to 1. + """ + super().__init__() + self.sigma_data = sigma_data + + def c_skip(self, sigma): + """Scaling factor for skip connection.""" + return self.sigma_data / (sigma**2 + self.sigma_data**2) + + def c_out(self, sigma): + """Scaling factor for output.""" + return sigma * self.sigma_data / torch.sqrt(sigma**2 + self.sigma_data**2) + + def c_in(self, sigma): + """Scaling factor for input.""" + return 1 / torch.sqrt(sigma**2 + self.sigma_data**2) + + def c_noise(self, sigma): + """Scaling factor for noise level.""" + return 1 / 4 * torch.log(sigma) diff --git a/train/gencast_demo.ipynb b/train/gencast_demo.ipynb index a970d52b..4e6b0855 100644 --- a/train/gencast_demo.ipynb +++ b/train/gencast_demo.ipynb @@ -77,7 +77,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [ { @@ -114,7 +114,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ @@ -484,29 +484,119 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Encoder and Decoder" + "## Denoiser" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set the dataset" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", + "\n", + "from graph_weather.data.gencast_dataloader import GenCastDataset\n", + "\n", + "atmospheric_features = [\"geopotential\", \n", + " \"specific_humidity\",\n", + " \"temperature\",\n", + " \"u_component_of_wind\",\n", + " \"v_component_of_wind\",\n", + " \"vertical_velocity\"]\n", + " \n", + "single_features = [\"2m_temperature\", \n", + " \"10m_u_component_of_wind\",\n", + " \"10m_v_component_of_wind\",\n", + " \"mean_sea_level_pressure\",\n", + " #\"sea_surface_temperature\",\n", + " \"total_precipitation_12hr\"]\n", + "\n", + "static_features = [\"geopotential_at_surface\", \n", + " \"land_sea_mask\"]\n", + "\n", + "obs_path = 'gs://weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_conservative.zarr'\n", + "#obs_path = 'gs://weatherbench2/datasets/era5/1959-2022-6h-1440x721.zarr'\n", + "#obs_path = 'gs://weatherbench2/datasets/era5/1959-2022-6h-512x256_equiangular_conservative.zarr'\n", + "\n", + "batch_size = 2\n", + "dataset = GenCastDataset(obs_path=obs_path,\n", + " atmospheric_features=atmospheric_features,\n", + " single_features=single_features,\n", + " static_features=static_features,\n", + " max_year=2018,\n", + " time_step=2)\n", + "\n", + "dataloader = DataLoader(dataset, batch_size=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "/home/gbruno16/Code/graph_weather_dev\n" + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[7], line 25\u001b[0m\n\u001b[1;32m 5\u001b[0m denoiser \u001b[38;5;241m=\u001b[39m Denoiser(\n\u001b[1;32m 6\u001b[0m grid_lon\u001b[38;5;241m=\u001b[39mdataset\u001b[38;5;241m.\u001b[39mgrid_lon,\n\u001b[1;32m 7\u001b[0m grid_lat\u001b[38;5;241m=\u001b[39mdataset\u001b[38;5;241m.\u001b[39mgrid_lat,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 15\u001b[0m device\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mdevice(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 16\u001b[0m )\n\u001b[1;32m 18\u001b[0m loss \u001b[38;5;241m=\u001b[39m WeightedMSELoss(\n\u001b[1;32m 19\u001b[0m grid_lat\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mtensor(dataset\u001b[38;5;241m.\u001b[39mgrid_lat),\n\u001b[1;32m 20\u001b[0m pressure_levels\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mtensor(dataset\u001b[38;5;241m.\u001b[39mpressure_levels),\n\u001b[1;32m 21\u001b[0m num_atmospheric_features\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlen\u001b[39m(dataset\u001b[38;5;241m.\u001b[39matmospheric_features),\n\u001b[1;32m 22\u001b[0m single_features_weights\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mtensor([\u001b[38;5;241m1.0\u001b[39m, \u001b[38;5;241m0.1\u001b[39m, \u001b[38;5;241m0.1\u001b[39m, \u001b[38;5;241m0.1\u001b[39m, \u001b[38;5;241m0.1\u001b[39m]),\n\u001b[1;32m 23\u001b[0m )\n\u001b[0;32m---> 25\u001b[0m prev_inputs, noise_level, corrupted_residuals, target_residuals \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43miter\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mdataloader\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/torch/utils/data/dataloader.py:631\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 628\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 629\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 630\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 631\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 632\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 634\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 635\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/torch/utils/data/dataloader.py:675\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 673\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 674\u001b[0m index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index() \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 675\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m 676\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[1;32m 677\u001b[0m data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory_device)\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 49\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 50\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mpossibly_batched_index\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 53\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 49\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 50\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 53\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n", + "File \u001b[0;32m~/Code/graph_weather_dev/graph_weather/data/gencast_dataloader.py:177\u001b[0m, in \u001b[0;36mGenCastDataset.__getitem__\u001b[0;34m(self, item)\u001b[0m\n\u001b[1;32m 170\u001b[0m inputs \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mnan_to_num(inputs)\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[1;32m 172\u001b[0m \u001b[38;5;66;03m# Load target data\u001b[39;00m\n\u001b[1;32m 173\u001b[0m ds_target_atm \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 174\u001b[0m \u001b[43mds_target\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43matmospheric_features\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_array\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 176\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtranspose\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlongitude\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlatitude\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlevel\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mvariable\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m--> 177\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\n\u001b[1;32m 178\u001b[0m )\n\u001b[1;32m 179\u001b[0m ds_target_atm \u001b[38;5;241m=\u001b[39m einops\u001b[38;5;241m.\u001b[39mrearrange(ds_target_atm, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlon lat lev var -> lon lat (var lev)\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 180\u001b[0m ds_target_single \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 181\u001b[0m ds_target[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msingle_features]\n\u001b[1;32m 182\u001b[0m \u001b[38;5;241m.\u001b[39mto_array()\n\u001b[1;32m 183\u001b[0m \u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlongitude\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlatitude\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvariable\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 184\u001b[0m \u001b[38;5;241m.\u001b[39mvalues\n\u001b[1;32m 185\u001b[0m )\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/xarray/core/dataarray.py:733\u001b[0m, in \u001b[0;36mDataArray.values\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 724\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m 725\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mvalues\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m np\u001b[38;5;241m.\u001b[39mndarray:\n\u001b[1;32m 726\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 727\u001b[0m \u001b[38;5;124;03m The array's data as a numpy.ndarray.\u001b[39;00m\n\u001b[1;32m 728\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 731\u001b[0m \u001b[38;5;124;03m type does not support coercion like this (e.g. cupy).\u001b[39;00m\n\u001b[1;32m 732\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 733\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/xarray/core/variable.py:614\u001b[0m, in \u001b[0;36mVariable.values\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 611\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m 612\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mvalues\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 613\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"The variable's data as a numpy.ndarray\"\"\"\u001b[39;00m\n\u001b[0;32m--> 614\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_as_array_or_item\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/xarray/core/variable.py:314\u001b[0m, in \u001b[0;36m_as_array_or_item\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m 300\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_as_array_or_item\u001b[39m(data):\n\u001b[1;32m 301\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Return the given values as a numpy array, or as an individual item if\u001b[39;00m\n\u001b[1;32m 302\u001b[0m \u001b[38;5;124;03m it's a 0d datetime64 or timedelta64 array.\u001b[39;00m\n\u001b[1;32m 303\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[38;5;124;03m TODO: remove this (replace with np.asarray) once these issues are fixed\u001b[39;00m\n\u001b[1;32m 313\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 314\u001b[0m data \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39masarray(data)\n\u001b[1;32m 315\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m data\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m data\u001b[38;5;241m.\u001b[39mdtype\u001b[38;5;241m.\u001b[39mkind \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mM\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/dask/array/core.py:1693\u001b[0m, in \u001b[0;36mArray.__array__\u001b[0;34m(self, dtype, **kwargs)\u001b[0m\n\u001b[1;32m 1692\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__array__\u001b[39m(\u001b[38;5;28mself\u001b[39m, dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m-> 1693\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1694\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype \u001b[38;5;129;01mand\u001b[39;00m x\u001b[38;5;241m.\u001b[39mdtype \u001b[38;5;241m!=\u001b[39m dtype:\n\u001b[1;32m 1695\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mastype(dtype)\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/dask/base.py:375\u001b[0m, in \u001b[0;36mDaskMethodsMixin.compute\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 351\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 352\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Compute this dask collection\u001b[39;00m\n\u001b[1;32m 353\u001b[0m \n\u001b[1;32m 354\u001b[0m \u001b[38;5;124;03m This turns a lazy Dask collection into its in-memory equivalent.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 373\u001b[0m \u001b[38;5;124;03m dask.compute\u001b[39;00m\n\u001b[1;32m 374\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 375\u001b[0m (result,) \u001b[38;5;241m=\u001b[39m \u001b[43mcompute\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtraverse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 376\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m result\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/dask/base.py:661\u001b[0m, in \u001b[0;36mcompute\u001b[0;34m(traverse, optimize_graph, scheduler, get, *args, **kwargs)\u001b[0m\n\u001b[1;32m 658\u001b[0m postcomputes\u001b[38;5;241m.\u001b[39mappend(x\u001b[38;5;241m.\u001b[39m__dask_postcompute__())\n\u001b[1;32m 660\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m shorten_traceback():\n\u001b[0;32m--> 661\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[43mschedule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdsk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeys\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 663\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m repack([f(r, \u001b[38;5;241m*\u001b[39ma) \u001b[38;5;28;01mfor\u001b[39;00m r, (f, a) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(results, postcomputes)])\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/queue.py:171\u001b[0m, in \u001b[0;36mQueue.get\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 170\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_qsize():\n\u001b[0;32m--> 171\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnot_empty\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m timeout \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 173\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtimeout\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m must be a non-negative number\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/threading.py:327\u001b[0m, in \u001b[0;36mCondition.wait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m: \u001b[38;5;66;03m# restore state no matter what (e.g., KeyboardInterrupt)\u001b[39;00m\n\u001b[1;32m 326\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 327\u001b[0m \u001b[43mwaiter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43macquire\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 328\u001b[0m gotit \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 329\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ - "%cd ../" + "import torch\n", + "\n", + "from graph_weather.models.gencast import Denoiser, WeightedMSELoss\n", + "\n", + "denoiser = Denoiser(\n", + " grid_lon=dataset.grid_lon,\n", + " grid_lat=dataset.grid_lat,\n", + " input_features_dim=dataset.input_features_dim,\n", + " output_features_dim=dataset.output_features_dim,\n", + " hidden_dims=[16, 16],\n", + " num_blocks=3,\n", + " num_heads=4,\n", + " splits=0,\n", + " num_hops=1,\n", + " device=torch.device(\"cpu\"),\n", + ")\n", + "\n", + "loss = WeightedMSELoss(\n", + " grid_lat=torch.tensor(dataset.grid_lat),\n", + " pressure_levels=torch.tensor(dataset.pressure_levels),\n", + " num_atmospheric_features=len(dataset.atmospheric_features),\n", + " single_features_weights=torch.tensor([1.0, 0.1, 0.1, 0.1, 0.1]),\n", + ")\n", + "\n", + "prev_inputs, noise_level, corrupted_residuals, target_residuals = next(iter(dataloader))\n" ] }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -545,10 +635,277 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(tensor([[-0.4831, -0.8099, 0.9474, -0.3155, 1.3251, -3.0360, 2.3721],\n", + " [-0.8147, -0.4318, 0.9071, -0.1851, 1.5138, -3.1564, 2.1669],\n", + " [-0.6758, -0.5384, 0.9154, -0.3039, 1.4295, -3.0700, 2.2433],\n", + " [-0.7029, -0.4444, 0.9977, 0.0664, 1.4838, -3.3822, 1.9816],\n", + " [-0.2588, -0.9820, 0.9762, -0.3303, 1.1857, -2.9864, 2.3956]],\n", + " grad_fn=),\n", + " tensor([[ 1.0823, -0.1741, -0.1074, 0.1718, -2.1206, 1.2814, -0.1335],\n", + " [ 1.4093, -0.1385, 0.4016, -0.0235, -1.7778, -0.0566, 0.1855],\n", + " [ 0.2196, -0.2307, -0.0837, 0.8277, -1.7540, 1.4246, -0.4036]],\n", + " grad_fn=))" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from graph_weather.models.gencast.layers.encoder import Encoder\n", + "grid_dim=3\n", + "mesh_dim=4\n", + "edge_dim=5\n", + "hidden_dims=[6,7]\n", + "encoder = Encoder(grid_dim=grid_dim, mesh_dim=mesh_dim, edge_dim=edge_dim, hidden_dims=hidden_dims, activation_layer=torch.nn.SiLU, use_layer_norm=True)\n", + "grid_features = torch.rand((sender_nodes, grid_dim))\n", + "mesh_features = torch.rand((receiver_nodes, mesh_dim))\n", + "edge_features = torch.rand((len(edge_index[0]), edge_dim))\n", + "encoder.forward(grid_features, mesh_features, edge_features, edge_index)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[1, 2, 3, 4]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "[1,2,3]+[4]" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/home/gbruno16/Code/graph_weather_dev\n" + ] + } + ], + "source": [ + "%cd ../" + ] + }, + { + "cell_type": "code", + "execution_count": 5, "metadata": {}, "outputs": [], - "source": [] + "source": [ + "from torch_geometric.data import Batch\n", + "import numpy as np\n", + "\n", + "from graph_weather.models.gencast import GraphBuilder\n", + "\n", + "grid_lat = np.arange(-90, 90, 90)\n", + "grid_lon = np.arange(0, 360, 180)\n", + "\n", + "graphs = GraphBuilder(grid_lat=grid_lat, grid_lon=grid_lon, splits=0, num_hops=0)\n", + "\n", + "g = graphs.mesh_graph\n" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "HeteroDataBatch(\n", + " grid_nodes={\n", + " x=[12, 3],\n", + " batch=[12],\n", + " ptr=[4],\n", + " },\n", + " mesh_nodes={\n", + " x=[36, 3],\n", + " batch=[36],\n", + " ptr=[4],\n", + " },\n", + " (grid_nodes, to, mesh_nodes)={\n", + " edge_index=[2, 6],\n", + " edge_attr=[6, 4],\n", + " }\n", + ")\n", + "tensor([[ 2, 3, 6, 7, 10, 11],\n", + " [ 8, 5, 20, 17, 32, 29]])\n" + ] + } + ], + "source": [ + "gb = Batch.from_data_list([graphs.g2m_graph]*3)\n", + "\n", + "print(gb)\n", + "\n", + "print(gb[\"grid_nodes\",\"to\",\"mesh_nodes\"].edge_index)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'x': tensor([[ 1.8759e-01, 5.0000e-01, 8.6603e-01],\n", + " [ 7.9465e-01, -5.0000e-01, 8.6603e-01],\n", + " [ 7.9465e-01, 1.0000e+00, 0.0000e+00],\n", + " [ 1.8759e-01, 5.0000e-01, -8.6603e-01],\n", + " [-7.9465e-01, 5.0000e-01, 8.6603e-01],\n", + " [ 1.8759e-01, -1.0000e+00, -8.7423e-08],\n", + " [-1.8759e-01, -5.0000e-01, 8.6603e-01],\n", + " [ 7.9465e-01, -5.0000e-01, -8.6603e-01],\n", + " [-1.8759e-01, 1.0000e+00, 0.0000e+00],\n", + " [-1.8759e-01, -5.0000e-01, -8.6603e-01],\n", + " [-7.9465e-01, 5.0000e-01, -8.6603e-01],\n", + " [-7.9465e-01, -1.0000e+00, -8.7423e-08]])}" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "graphs.m2g_graph[\"mesh_nodes\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import torch\n", + "from graph_weather.models.gencast import Denoiser\n", + "\n", + "grid_lat = np.arange(-90, 90, 90)\n", + "grid_lon = np.arange(0, 360, 180)\n", + "batch_size = 9\n", + "input_features_dim=10\n", + "output_features_dim=5\n", + "\n", + "denoiser = Denoiser(grid_lon=grid_lon,\n", + " grid_lat=grid_lat,\n", + " input_features_dim=10,\n", + " output_features_dim=5,\n", + " hidden_dims=[4,8],\n", + " splits=0,\n", + " num_hops=0)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "torch.Size([9, 2, 2, 5])\n" + ] + } + ], + "source": [ + "grid_features = torch.rand((batch_size, len(grid_lon), len(grid_lat),2*input_features_dim))\n", + "print(denoiser.f_theta(grid_features, None).shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([9, 2, 2, 5])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import einops\n", + "einops.rearrange(\n", + " torch.rand([9,4,5]), \"b (lon lat) f -> b lon lat f\", lon=grid_features.shape[1]\n", + " ).shape" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "from torch_geometric.data import Batch\n", + "g2m_batched = Batch([graphs.g2m_graph] * 2)" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "HeteroDataBatch(\n", + " grid_nodes={\n", + " x=[8, 3],\n", + " batch=[8],\n", + " ptr=[3],\n", + " },\n", + " mesh_nodes={\n", + " x=[24, 3],\n", + " batch=[24],\n", + " ptr=[3],\n", + " },\n", + " (grid_nodes, to, mesh_nodes)={\n", + " edge_index=[2, 4],\n", + " edge_attr=[4, 4],\n", + " }\n", + ")" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "Batch.from_data_list([graphs.g2m_graph] * 2)" + ] } ], "metadata": { From aec58387c9533c3025cc4e5c35d34b6c8ab26af9 Mon Sep 17 00:00:00 2001 From: gbruno16 Date: Wed, 12 Jun 2024 12:11:16 +0200 Subject: [PATCH 3/7] Rebase --- graph_weather/data/gencast_dataloader.py | 20 +++++------ .../models/gencast/weighted_mse_loss.py | 9 +++-- tests/test_model.py | 4 +-- train/gencast_demo.ipynb | 35 +++---------------- 4 files changed, 23 insertions(+), 45 deletions(-) diff --git a/graph_weather/data/gencast_dataloader.py b/graph_weather/data/gencast_dataloader.py index 93e9e221..35bc77b4 100644 --- a/graph_weather/data/gencast_dataloader.py +++ b/graph_weather/data/gencast_dataloader.py @@ -167,7 +167,7 @@ def __getitem__(self, item): # Concatenate timesteps inputs = np.concatenate([inputs[0, :, :, :], inputs[1, :, :, :]], axis=-1) - inputs = np.nan_to_num(inputs).astype(np.float32) + prev_inputs = np.nan_to_num(inputs).astype(np.float32) # Load target data ds_target_atm = ( @@ -192,16 +192,16 @@ def __getitem__(self, item): target_residuals = np.nan_to_num(target_residuals).astype(np.float32) # Corrupt targets with noise - noise_level = np.array([sample_noise_level()]).astype(np.float32) + noise_levels = np.array([sample_noise_level()]).astype(np.float32) noise = generate_isotropic_noise( num_lat=self.num_lat, num_samples=target_residuals.shape[-1] ) - corrupted_residuals = target_residuals + noise_level * noise + corrupted_targets = target_residuals + noise_levels * noise return ( - inputs, - noise_level, - corrupted_residuals, + corrupted_targets, + prev_inputs, + noise_levels, target_residuals, ) @@ -372,7 +372,7 @@ def __getitem__(self, item): inputs = np.concatenate([batched_inputs_norm, ds_clock], axis=-1) # Concatenate timesteps inputs = np.concatenate([inputs[:, 0, :, :, :], inputs[:, 1, :, :, :]], axis=-1) - inputs = np.nan_to_num(inputs).astype(np.float32) + prev_inputs = np.nan_to_num(inputs).astype(np.float32) # Compute targets residuals raw_targets = np.concatenate([ds_atm, ds_single], axis=-1) @@ -382,13 +382,13 @@ def __getitem__(self, item): # Corrupt targets with noise noise_levels = np.zeros((self.batch_size, 1), dtype=np.float32) - corrupted_residuals = np.zeros_like(target_residuals, dtype=np.float32) + corrupted_targets = np.zeros_like(target_residuals, dtype=np.float32) for b in range(self.batch_size): noise_level = sample_noise_level() noise = generate_isotropic_noise( num_lat=self.num_lat, num_samples=target_residuals.shape[-1] ) - corrupted_residuals[b] = target_residuals[b] + noise_level * noise + corrupted_targets[b] = target_residuals[b] + noise_level * noise noise_levels[b] = noise_level - return (inputs, noise_levels, corrupted_residuals, target_residuals) + return (corrupted_targets, prev_inputs, noise_levels, target_residuals) diff --git a/graph_weather/models/gencast/weighted_mse_loss.py b/graph_weather/models/gencast/weighted_mse_loss.py index bded0c53..baa2b21c 100644 --- a/graph_weather/models/gencast/weighted_mse_loss.py +++ b/graph_weather/models/gencast/weighted_mse_loss.py @@ -68,15 +68,18 @@ def _lambda_sigma(self, noise_level): return noise_weights # [batch, 1] def forward( - self, pred: torch.Tensor, target: torch.Tensor, noise_level: torch.Tensor + self, + pred: torch.Tensor, + noise_level: torch.Tensor, + target: torch.Tensor, ) -> torch.Tensor: """Compute the loss. Args: pred (torch.Tensor): prediction of the model [batch, lon, lat, var]. - target (torch.Tensor): target tensor [batch, lon, lat, var]. noise_level (torch.Tensor): noise levels fed to the model for the corresponding predictions [batch, 1]. + target (torch.Tensor): target tensor [batch, lon, lat, var]. Returns: torch.Tensor: weighted MSE loss. @@ -84,7 +87,7 @@ def forward( # check shapes if not (pred.shape == target.shape): raise ValueError( - "redictions and targets must have same shape. The actual shapes " + "Predictions and targets must have same shape. The actual shapes " f"are {pred.shape} and {target.shape}." ) if not (len(pred.shape) == 4): diff --git a/tests/test_model.py b/tests/test_model.py index bc19cc79..5cae6fe1 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -325,5 +325,5 @@ def test_gencast_loss(): preds = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) noise_levels = torch.rand((batch_size, 1)) - targets = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) - assert loss.forward(preds, targets, noise_levels) is not None + targets = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) + assert loss.forward(preds, noise_levels, targets) is not None diff --git a/train/gencast_demo.ipynb b/train/gencast_demo.ipynb index 4e6b0855..2a005e30 100644 --- a/train/gencast_demo.ipynb +++ b/train/gencast_demo.ipynb @@ -77,7 +77,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -496,7 +496,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -538,34 +538,9 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 5, "metadata": {}, - "outputs": [ - { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[7], line 25\u001b[0m\n\u001b[1;32m 5\u001b[0m denoiser \u001b[38;5;241m=\u001b[39m Denoiser(\n\u001b[1;32m 6\u001b[0m grid_lon\u001b[38;5;241m=\u001b[39mdataset\u001b[38;5;241m.\u001b[39mgrid_lon,\n\u001b[1;32m 7\u001b[0m grid_lat\u001b[38;5;241m=\u001b[39mdataset\u001b[38;5;241m.\u001b[39mgrid_lat,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 15\u001b[0m device\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mdevice(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 16\u001b[0m )\n\u001b[1;32m 18\u001b[0m loss \u001b[38;5;241m=\u001b[39m WeightedMSELoss(\n\u001b[1;32m 19\u001b[0m grid_lat\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mtensor(dataset\u001b[38;5;241m.\u001b[39mgrid_lat),\n\u001b[1;32m 20\u001b[0m pressure_levels\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mtensor(dataset\u001b[38;5;241m.\u001b[39mpressure_levels),\n\u001b[1;32m 21\u001b[0m num_atmospheric_features\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlen\u001b[39m(dataset\u001b[38;5;241m.\u001b[39matmospheric_features),\n\u001b[1;32m 22\u001b[0m single_features_weights\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mtensor([\u001b[38;5;241m1.0\u001b[39m, \u001b[38;5;241m0.1\u001b[39m, \u001b[38;5;241m0.1\u001b[39m, \u001b[38;5;241m0.1\u001b[39m, \u001b[38;5;241m0.1\u001b[39m]),\n\u001b[1;32m 23\u001b[0m )\n\u001b[0;32m---> 25\u001b[0m prev_inputs, noise_level, corrupted_residuals, target_residuals \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43miter\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mdataloader\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/torch/utils/data/dataloader.py:631\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 628\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 629\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 630\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 631\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 632\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 634\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 635\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/torch/utils/data/dataloader.py:675\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 673\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 674\u001b[0m index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index() \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 675\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m 676\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[1;32m 677\u001b[0m data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory_device)\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 49\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 50\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mpossibly_batched_index\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 53\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 49\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 50\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 53\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n", - "File \u001b[0;32m~/Code/graph_weather_dev/graph_weather/data/gencast_dataloader.py:177\u001b[0m, in \u001b[0;36mGenCastDataset.__getitem__\u001b[0;34m(self, item)\u001b[0m\n\u001b[1;32m 170\u001b[0m inputs \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39mnan_to_num(inputs)\u001b[38;5;241m.\u001b[39mastype(np\u001b[38;5;241m.\u001b[39mfloat32)\n\u001b[1;32m 172\u001b[0m \u001b[38;5;66;03m# Load target data\u001b[39;00m\n\u001b[1;32m 173\u001b[0m ds_target_atm \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 174\u001b[0m \u001b[43mds_target\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43matmospheric_features\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 175\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_array\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 176\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtranspose\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlongitude\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlatitude\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlevel\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mvariable\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m--> 177\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\n\u001b[1;32m 178\u001b[0m )\n\u001b[1;32m 179\u001b[0m ds_target_atm \u001b[38;5;241m=\u001b[39m einops\u001b[38;5;241m.\u001b[39mrearrange(ds_target_atm, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlon lat lev var -> lon lat (var lev)\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 180\u001b[0m ds_target_single \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 181\u001b[0m ds_target[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msingle_features]\n\u001b[1;32m 182\u001b[0m \u001b[38;5;241m.\u001b[39mto_array()\n\u001b[1;32m 183\u001b[0m \u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlongitude\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlatitude\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvariable\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 184\u001b[0m \u001b[38;5;241m.\u001b[39mvalues\n\u001b[1;32m 185\u001b[0m )\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/xarray/core/dataarray.py:733\u001b[0m, in \u001b[0;36mDataArray.values\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 724\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m 725\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mvalues\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m np\u001b[38;5;241m.\u001b[39mndarray:\n\u001b[1;32m 726\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 727\u001b[0m \u001b[38;5;124;03m The array's data as a numpy.ndarray.\u001b[39;00m\n\u001b[1;32m 728\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 731\u001b[0m \u001b[38;5;124;03m type does not support coercion like this (e.g. cupy).\u001b[39;00m\n\u001b[1;32m 732\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 733\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/xarray/core/variable.py:614\u001b[0m, in \u001b[0;36mVariable.values\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 611\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m 612\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mvalues\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 613\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"The variable's data as a numpy.ndarray\"\"\"\u001b[39;00m\n\u001b[0;32m--> 614\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_as_array_or_item\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/xarray/core/variable.py:314\u001b[0m, in \u001b[0;36m_as_array_or_item\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m 300\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_as_array_or_item\u001b[39m(data):\n\u001b[1;32m 301\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Return the given values as a numpy array, or as an individual item if\u001b[39;00m\n\u001b[1;32m 302\u001b[0m \u001b[38;5;124;03m it's a 0d datetime64 or timedelta64 array.\u001b[39;00m\n\u001b[1;32m 303\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[38;5;124;03m TODO: remove this (replace with np.asarray) once these issues are fixed\u001b[39;00m\n\u001b[1;32m 313\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 314\u001b[0m data \u001b[38;5;241m=\u001b[39m np\u001b[38;5;241m.\u001b[39masarray(data)\n\u001b[1;32m 315\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m data\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m data\u001b[38;5;241m.\u001b[39mdtype\u001b[38;5;241m.\u001b[39mkind \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mM\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/dask/array/core.py:1693\u001b[0m, in \u001b[0;36mArray.__array__\u001b[0;34m(self, dtype, **kwargs)\u001b[0m\n\u001b[1;32m 1692\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__array__\u001b[39m(\u001b[38;5;28mself\u001b[39m, dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m-> 1693\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1694\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype \u001b[38;5;129;01mand\u001b[39;00m x\u001b[38;5;241m.\u001b[39mdtype \u001b[38;5;241m!=\u001b[39m dtype:\n\u001b[1;32m 1695\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mastype(dtype)\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/dask/base.py:375\u001b[0m, in \u001b[0;36mDaskMethodsMixin.compute\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 351\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 352\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Compute this dask collection\u001b[39;00m\n\u001b[1;32m 353\u001b[0m \n\u001b[1;32m 354\u001b[0m \u001b[38;5;124;03m This turns a lazy Dask collection into its in-memory equivalent.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 373\u001b[0m \u001b[38;5;124;03m dask.compute\u001b[39;00m\n\u001b[1;32m 374\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 375\u001b[0m (result,) \u001b[38;5;241m=\u001b[39m \u001b[43mcompute\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtraverse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 376\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m result\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/dask/base.py:661\u001b[0m, in \u001b[0;36mcompute\u001b[0;34m(traverse, optimize_graph, scheduler, get, *args, **kwargs)\u001b[0m\n\u001b[1;32m 658\u001b[0m postcomputes\u001b[38;5;241m.\u001b[39mappend(x\u001b[38;5;241m.\u001b[39m__dask_postcompute__())\n\u001b[1;32m 660\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m shorten_traceback():\n\u001b[0;32m--> 661\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[43mschedule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdsk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeys\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 663\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m repack([f(r, \u001b[38;5;241m*\u001b[39ma) \u001b[38;5;28;01mfor\u001b[39;00m r, (f, a) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(results, postcomputes)])\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/queue.py:171\u001b[0m, in \u001b[0;36mQueue.get\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 170\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_qsize():\n\u001b[0;32m--> 171\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnot_empty\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m timeout \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 173\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtimeout\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m must be a non-negative number\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/threading.py:327\u001b[0m, in \u001b[0;36mCondition.wait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m: \u001b[38;5;66;03m# restore state no matter what (e.g., KeyboardInterrupt)\u001b[39;00m\n\u001b[1;32m 326\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 327\u001b[0m \u001b[43mwaiter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43macquire\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 328\u001b[0m gotit \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 329\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " - ] - } - ], + "outputs": [], "source": [ "import torch\n", "\n", @@ -591,7 +566,7 @@ " single_features_weights=torch.tensor([1.0, 0.1, 0.1, 0.1, 0.1]),\n", ")\n", "\n", - "prev_inputs, noise_level, corrupted_residuals, target_residuals = next(iter(dataloader))\n" + "corrupted_residuals, prev_inputs, noise_levels, target_residuals = next(iter(dataloader))\n" ] }, { From 87a805629430f213528712c4675ed111158fa3f7 Mon Sep 17 00:00:00 2001 From: gbruno16 Date: Wed, 12 Jun 2024 13:03:05 +0200 Subject: [PATCH 4/7] Add denoiser test --- graph_weather/data/gencast_dataloader.py | 2 +- graph_weather/models/gencast/denoiser.py | 10 ++++--- tests/test_model.py | 33 +++++++++++++++++++++++- train/gencast_demo.ipynb | 27 ++++++++++++++++++- 4 files changed, 65 insertions(+), 7 deletions(-) diff --git a/graph_weather/data/gencast_dataloader.py b/graph_weather/data/gencast_dataloader.py index 35bc77b4..460765f6 100644 --- a/graph_weather/data/gencast_dataloader.py +++ b/graph_weather/data/gencast_dataloader.py @@ -58,7 +58,7 @@ def __init__( self.output_features_dim = len(atmospheric_features) * len(self.pressure_levels) + len( single_features ) - self.input_features_dim = self.output_features_dim + len(static_features) + self.input_features_dim = self.output_features_dim + len(static_features) + 4 self.time_step = time_step # e.g. 12h steps correspond to time_step = 2 in a 6h dataset diff --git a/graph_weather/models/gencast/denoiser.py b/graph_weather/models/gencast/denoiser.py index 7038f118..482d70b7 100644 --- a/graph_weather/models/gencast/denoiser.py +++ b/graph_weather/models/gencast/denoiser.py @@ -62,7 +62,7 @@ def __init__( # Initialize Encoder self.encoder = Encoder( - grid_dim=self.graphs.grid_nodes_dim + 2 * input_features_dim, + grid_dim=output_features_dim + 2 * input_features_dim + self.graphs.grid_nodes_dim, mesh_dim=self.graphs.mesh_nodes_dim, edge_dim=self.graphs.g2m_edges_dim, hidden_dims=hidden_dims, @@ -89,9 +89,11 @@ def _check_shapes(self, corrupted_targets, prev_inputs, noise_levels): exp_noise_shape = (batch_size, 1) if not all( - corrupted_targets.shape == exp_targets_shape, - prev_inputs.shape == exp_inputs_shape, - noise_levels.shape == exp_noise_shape, + [ + corrupted_targets.shape == exp_targets_shape, + prev_inputs.shape == exp_inputs_shape, + noise_levels.shape == exp_noise_shape, + ] ): raise ValueError( "The shapes of the input tensors don't match with the initialization parameters: " diff --git a/tests/test_model.py b/tests/test_model.py index 5cae6fe1..d094fba7 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -18,7 +18,7 @@ generate_isotropic_noise, sample_noise_level, ) -from graph_weather.models.gencast import GraphBuilder, WeightedMSELoss +from graph_weather.models.gencast import GraphBuilder, WeightedMSELoss, Denoiser def test_encoder(): @@ -327,3 +327,34 @@ def test_gencast_loss(): noise_levels = torch.rand((batch_size, 1)) targets = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) assert loss.forward(preds, noise_levels, targets) is not None + +def test_gencast_denoiser(): + grid_lat = np.arange(-90, 90, 1) + grid_lon = np.arange(0, 360, 1) + input_features_dim = 10 + output_features_dim = 5 + batch_size = 3 + + denoiser = Denoiser( + grid_lon=grid_lon, + grid_lat=grid_lat, + input_features_dim=input_features_dim, + output_features_dim=output_features_dim, + hidden_dims=[16, 32], + num_blocks=3, + num_heads=4, + splits=0, + num_hops=1, + device=torch.device("cpu"), + ).eval() + + corrupted_targets = torch.randn((batch_size, len(grid_lon), len(grid_lat), output_features_dim)) + prev_inputs = torch.randn((batch_size, len(grid_lon), len(grid_lat), 2*input_features_dim)) + noise_levels = torch.randn((batch_size, 1)) + + with torch.no_grad(): + preds = denoiser(corrupted_targets=corrupted_targets, + prev_inputs=prev_inputs, + noise_levels=noise_levels) + + assert not torch.isnan(preds).any() \ No newline at end of file diff --git a/train/gencast_demo.ipynb b/train/gencast_demo.ipynb index 2a005e30..84d20064 100644 --- a/train/gencast_demo.ipynb +++ b/train/gencast_demo.ipynb @@ -540,7 +540,32 @@ "cell_type": "code", "execution_count": 5, "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[5], line 25\u001b[0m\n\u001b[1;32m 5\u001b[0m denoiser \u001b[38;5;241m=\u001b[39m Denoiser(\n\u001b[1;32m 6\u001b[0m grid_lon\u001b[38;5;241m=\u001b[39mdataset\u001b[38;5;241m.\u001b[39mgrid_lon,\n\u001b[1;32m 7\u001b[0m grid_lat\u001b[38;5;241m=\u001b[39mdataset\u001b[38;5;241m.\u001b[39mgrid_lat,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 15\u001b[0m device\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mdevice(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 16\u001b[0m )\n\u001b[1;32m 18\u001b[0m loss \u001b[38;5;241m=\u001b[39m WeightedMSELoss(\n\u001b[1;32m 19\u001b[0m grid_lat\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mtensor(dataset\u001b[38;5;241m.\u001b[39mgrid_lat),\n\u001b[1;32m 20\u001b[0m pressure_levels\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mtensor(dataset\u001b[38;5;241m.\u001b[39mpressure_levels),\n\u001b[1;32m 21\u001b[0m num_atmospheric_features\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlen\u001b[39m(dataset\u001b[38;5;241m.\u001b[39matmospheric_features),\n\u001b[1;32m 22\u001b[0m single_features_weights\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mtensor([\u001b[38;5;241m1.0\u001b[39m, \u001b[38;5;241m0.1\u001b[39m, \u001b[38;5;241m0.1\u001b[39m, \u001b[38;5;241m0.1\u001b[39m, \u001b[38;5;241m0.1\u001b[39m]),\n\u001b[1;32m 23\u001b[0m )\n\u001b[0;32m---> 25\u001b[0m prev_inputs, noise_level, corrupted_residuals, target_residuals \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43miter\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mdataloader\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/torch/utils/data/dataloader.py:631\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 628\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 629\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 630\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 631\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 632\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 634\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 635\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/torch/utils/data/dataloader.py:675\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 673\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 674\u001b[0m index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index() \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 675\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m 676\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[1;32m 677\u001b[0m data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory_device)\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 49\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 50\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mpossibly_batched_index\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 53\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 49\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 50\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 53\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n", + "File \u001b[0;32m~/Code/graph_weather_dev/graph_weather/data/gencast_dataloader.py:142\u001b[0m, in \u001b[0;36mGenCastDataset.__getitem__\u001b[0;34m(self, item)\u001b[0m\n\u001b[1;32m 135\u001b[0m ds_target \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39misel(time\u001b[38;5;241m=\u001b[39mitem \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtime_step)\n\u001b[1;32m 137\u001b[0m \u001b[38;5;66;03m# Load inputs data\u001b[39;00m\n\u001b[1;32m 138\u001b[0m ds_inputs_atm \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 139\u001b[0m \u001b[43mds_inputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43matmospheric_features\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 140\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_array\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtranspose\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtime\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlongitude\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlatitude\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlevel\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mvariable\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m--> 142\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\n\u001b[1;32m 143\u001b[0m )\n\u001b[1;32m 144\u001b[0m ds_inputs_atm \u001b[38;5;241m=\u001b[39m einops\u001b[38;5;241m.\u001b[39mrearrange(ds_inputs_atm, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mt lon lat lev var -> t lon lat (var lev)\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 145\u001b[0m ds_inputs_single \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 146\u001b[0m ds_inputs[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msingle_features]\n\u001b[1;32m 147\u001b[0m \u001b[38;5;241m.\u001b[39mto_array()\n\u001b[1;32m 148\u001b[0m \u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtime\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlongitude\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlatitude\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvariable\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 149\u001b[0m \u001b[38;5;241m.\u001b[39mvalues\n\u001b[1;32m 150\u001b[0m )\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/xarray/core/dataarray.py:733\u001b[0m, in \u001b[0;36mDataArray.values\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 724\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m 725\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mvalues\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m np\u001b[38;5;241m.\u001b[39mndarray:\n\u001b[1;32m 726\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 727\u001b[0m \u001b[38;5;124;03m The array's data as a numpy.ndarray.\u001b[39;00m\n\u001b[1;32m 728\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 731\u001b[0m \u001b[38;5;124;03m type does not support coercion like this (e.g. cupy).\u001b[39;00m\n\u001b[1;32m 732\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 733\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/xarray/core/variable.py:614\u001b[0m, in \u001b[0;36mVariable.values\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 611\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m 612\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mvalues\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 613\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"The variable's data as a numpy.ndarray\"\"\"\u001b[39;00m\n\u001b[0;32m--> 614\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_as_array_or_item\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data\u001b[49m\u001b[43m)\u001b[49m\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/xarray/core/variable.py:314\u001b[0m, in \u001b[0;36m_as_array_or_item\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m 300\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_as_array_or_item\u001b[39m(data):\n\u001b[1;32m 301\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Return the given values as a numpy array, or as an individual item if\u001b[39;00m\n\u001b[1;32m 302\u001b[0m \u001b[38;5;124;03m it's a 0d datetime64 or timedelta64 array.\u001b[39;00m\n\u001b[1;32m 303\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[38;5;124;03m TODO: remove this (replace with np.asarray) once these issues are fixed\u001b[39;00m\n\u001b[1;32m 313\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 314\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43masarray\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 315\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m data\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m data\u001b[38;5;241m.\u001b[39mdtype\u001b[38;5;241m.\u001b[39mkind \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mM\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/dask/array/core.py:1693\u001b[0m, in \u001b[0;36mArray.__array__\u001b[0;34m(self, dtype, **kwargs)\u001b[0m\n\u001b[1;32m 1692\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__array__\u001b[39m(\u001b[38;5;28mself\u001b[39m, dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m-> 1693\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1694\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype \u001b[38;5;129;01mand\u001b[39;00m x\u001b[38;5;241m.\u001b[39mdtype \u001b[38;5;241m!=\u001b[39m dtype:\n\u001b[1;32m 1695\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mastype(dtype)\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/dask/base.py:375\u001b[0m, in \u001b[0;36mDaskMethodsMixin.compute\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 351\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 352\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Compute this dask collection\u001b[39;00m\n\u001b[1;32m 353\u001b[0m \n\u001b[1;32m 354\u001b[0m \u001b[38;5;124;03m This turns a lazy Dask collection into its in-memory equivalent.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 373\u001b[0m \u001b[38;5;124;03m dask.compute\u001b[39;00m\n\u001b[1;32m 374\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 375\u001b[0m (result,) \u001b[38;5;241m=\u001b[39m \u001b[43mcompute\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtraverse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 376\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m result\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/dask/base.py:661\u001b[0m, in \u001b[0;36mcompute\u001b[0;34m(traverse, optimize_graph, scheduler, get, *args, **kwargs)\u001b[0m\n\u001b[1;32m 658\u001b[0m postcomputes\u001b[38;5;241m.\u001b[39mappend(x\u001b[38;5;241m.\u001b[39m__dask_postcompute__())\n\u001b[1;32m 660\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m shorten_traceback():\n\u001b[0;32m--> 661\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[43mschedule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdsk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeys\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 663\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m repack([f(r, \u001b[38;5;241m*\u001b[39ma) \u001b[38;5;28;01mfor\u001b[39;00m r, (f, a) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(results, postcomputes)])\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/queue.py:171\u001b[0m, in \u001b[0;36mQueue.get\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 170\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_qsize():\n\u001b[0;32m--> 171\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnot_empty\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m timeout \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 173\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtimeout\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m must be a non-negative number\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", + "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/threading.py:327\u001b[0m, in \u001b[0;36mCondition.wait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m: \u001b[38;5;66;03m# restore state no matter what (e.g., KeyboardInterrupt)\u001b[39;00m\n\u001b[1;32m 326\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 327\u001b[0m \u001b[43mwaiter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43macquire\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 328\u001b[0m gotit \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 329\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + ] + } + ], "source": [ "import torch\n", "\n", From b9e7816c9586d6adf1be433a4f3459a968635024 Mon Sep 17 00:00:00 2001 From: gbruno16 Date: Fri, 14 Jun 2024 08:26:36 +0200 Subject: [PATCH 5/7] Add demo --- graph_weather/models/gencast/utils/noise.py | 2 +- train/gencast_demo.ipynb | 424 ++++---------------- 2 files changed, 68 insertions(+), 358 deletions(-) diff --git a/graph_weather/models/gencast/utils/noise.py b/graph_weather/models/gencast/utils/noise.py index 99964fef..e44f50b3 100644 --- a/graph_weather/models/gencast/utils/noise.py +++ b/graph_weather/models/gencast/utils/noise.py @@ -26,7 +26,7 @@ def generate_isotropic_noise(num_lat, num_samples=1): for i in range(num_samples): clm = pysh.SHCoeffs.from_random(power) grid[:, :, i] = clm.expand(grid="DH2", extend=False).to_array().transpose() - return grid + return grid.astype(np.float32) def sample_noise_level(sigma_min=0.02, sigma_max=88, rho=7): diff --git a/train/gencast_demo.ipynb b/train/gencast_demo.ipynb index 84d20064..d8fd6d20 100644 --- a/train/gencast_demo.ipynb +++ b/train/gencast_demo.ipynb @@ -496,7 +496,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -504,65 +504,66 @@ "\n", "from graph_weather.data.gencast_dataloader import GenCastDataset\n", "\n", - "atmospheric_features = [\"geopotential\", \n", - " \"specific_humidity\",\n", - " \"temperature\",\n", - " \"u_component_of_wind\",\n", - " \"v_component_of_wind\",\n", - " \"vertical_velocity\"]\n", - " \n", - "single_features = [\"2m_temperature\", \n", - " \"10m_u_component_of_wind\",\n", - " \"10m_v_component_of_wind\",\n", - " \"mean_sea_level_pressure\",\n", - " #\"sea_surface_temperature\",\n", - " \"total_precipitation_12hr\"]\n", + "atmospheric_features = [\n", + " \"geopotential\",\n", + " \"specific_humidity\",\n", + " \"temperature\",\n", + " \"u_component_of_wind\",\n", + " \"v_component_of_wind\",\n", + " \"vertical_velocity\",\n", + "]\n", + "single_features = [\n", + " \"2m_temperature\",\n", + " \"10m_u_component_of_wind\",\n", + " \"10m_v_component_of_wind\",\n", + " \"mean_sea_level_pressure\",\n", + " # \"sea_surface_temperature\",\n", + " \"total_precipitation_12hr\",\n", + "]\n", + "static_features = [\n", + " \"geopotential_at_surface\",\n", + " \"land_sea_mask\",\n", + "]\n", "\n", - "static_features = [\"geopotential_at_surface\", \n", - " \"land_sea_mask\"]\n", + "obs_path = \"gs://weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_conservative.zarr\"\n", + "# obs_path = 'gs://weatherbench2/datasets/era5/1959-2022-6h-1440x721.zarr'\n", + "# obs_path = 'gs://weatherbench2/datasets/era5/1959-2022-6h-512x256_equiangular_conservative.zarr'\n", "\n", - "obs_path = 'gs://weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_conservative.zarr'\n", - "#obs_path = 'gs://weatherbench2/datasets/era5/1959-2022-6h-1440x721.zarr'\n", - "#obs_path = 'gs://weatherbench2/datasets/era5/1959-2022-6h-512x256_equiangular_conservative.zarr'\n", + "dataset = GenCastDataset(\n", + " obs_path=obs_path,\n", + " atmospheric_features=atmospheric_features,\n", + " single_features=single_features,\n", + " static_features=static_features,\n", + " max_year=2018,\n", + " time_step=2,\n", + ")\n", "\n", "batch_size = 2\n", - "dataset = GenCastDataset(obs_path=obs_path,\n", - " atmospheric_features=atmospheric_features,\n", - " single_features=single_features,\n", - " static_features=static_features,\n", - " max_year=2018,\n", - " time_step=2)\n", - "\n", - "dataloader = DataLoader(dataset, batch_size=1)" + "dataloader = DataLoader(dataset, batch_size=batch_size)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Predict" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "metadata": {}, "outputs": [ { - "ename": "KeyboardInterrupt", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[5], line 25\u001b[0m\n\u001b[1;32m 5\u001b[0m denoiser \u001b[38;5;241m=\u001b[39m Denoiser(\n\u001b[1;32m 6\u001b[0m grid_lon\u001b[38;5;241m=\u001b[39mdataset\u001b[38;5;241m.\u001b[39mgrid_lon,\n\u001b[1;32m 7\u001b[0m grid_lat\u001b[38;5;241m=\u001b[39mdataset\u001b[38;5;241m.\u001b[39mgrid_lat,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 15\u001b[0m device\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mdevice(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mcpu\u001b[39m\u001b[38;5;124m\"\u001b[39m),\n\u001b[1;32m 16\u001b[0m )\n\u001b[1;32m 18\u001b[0m loss \u001b[38;5;241m=\u001b[39m WeightedMSELoss(\n\u001b[1;32m 19\u001b[0m grid_lat\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mtensor(dataset\u001b[38;5;241m.\u001b[39mgrid_lat),\n\u001b[1;32m 20\u001b[0m pressure_levels\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mtensor(dataset\u001b[38;5;241m.\u001b[39mpressure_levels),\n\u001b[1;32m 21\u001b[0m num_atmospheric_features\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mlen\u001b[39m(dataset\u001b[38;5;241m.\u001b[39matmospheric_features),\n\u001b[1;32m 22\u001b[0m single_features_weights\u001b[38;5;241m=\u001b[39mtorch\u001b[38;5;241m.\u001b[39mtensor([\u001b[38;5;241m1.0\u001b[39m, \u001b[38;5;241m0.1\u001b[39m, \u001b[38;5;241m0.1\u001b[39m, \u001b[38;5;241m0.1\u001b[39m, \u001b[38;5;241m0.1\u001b[39m]),\n\u001b[1;32m 23\u001b[0m )\n\u001b[0;32m---> 25\u001b[0m prev_inputs, noise_level, corrupted_residuals, target_residuals \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mnext\u001b[39;49m\u001b[43m(\u001b[49m\u001b[38;5;28;43miter\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mdataloader\u001b[49m\u001b[43m)\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/torch/utils/data/dataloader.py:631\u001b[0m, in \u001b[0;36m_BaseDataLoaderIter.__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 628\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_sampler_iter \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 629\u001b[0m \u001b[38;5;66;03m# TODO(https://github.com/pytorch/pytorch/issues/76750)\u001b[39;00m\n\u001b[1;32m 630\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_reset() \u001b[38;5;66;03m# type: ignore[call-arg]\u001b[39;00m\n\u001b[0;32m--> 631\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_next_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 632\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[1;32m 633\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_dataset_kind \u001b[38;5;241m==\u001b[39m _DatasetKind\u001b[38;5;241m.\u001b[39mIterable \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 634\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m \u001b[38;5;129;01mand\u001b[39;00m \\\n\u001b[1;32m 635\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_num_yielded \u001b[38;5;241m>\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_IterableDataset_len_called:\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/torch/utils/data/dataloader.py:675\u001b[0m, in \u001b[0;36m_SingleProcessDataLoaderIter._next_data\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 673\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_next_data\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 674\u001b[0m index \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_next_index() \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[0;32m--> 675\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_dataset_fetcher\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mfetch\u001b[49m\u001b[43m(\u001b[49m\u001b[43mindex\u001b[49m\u001b[43m)\u001b[49m \u001b[38;5;66;03m# may raise StopIteration\u001b[39;00m\n\u001b[1;32m 676\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory:\n\u001b[1;32m 677\u001b[0m data \u001b[38;5;241m=\u001b[39m _utils\u001b[38;5;241m.\u001b[39mpin_memory\u001b[38;5;241m.\u001b[39mpin_memory(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_pin_memory_device)\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51\u001b[0m, in \u001b[0;36m_MapDatasetFetcher.fetch\u001b[0;34m(self, possibly_batched_index)\u001b[0m\n\u001b[1;32m 49\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 50\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mfor\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43midx\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;129;43;01min\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[43mpossibly_batched_index\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 53\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py:51\u001b[0m, in \u001b[0;36m\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m 49\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset\u001b[38;5;241m.\u001b[39m__getitems__(possibly_batched_index)\n\u001b[1;32m 50\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m---> 51\u001b[0m data \u001b[38;5;241m=\u001b[39m [\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[43midx\u001b[49m\u001b[43m]\u001b[49m \u001b[38;5;28;01mfor\u001b[39;00m idx \u001b[38;5;129;01min\u001b[39;00m possibly_batched_index]\n\u001b[1;32m 52\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 53\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdataset[possibly_batched_index]\n", - "File \u001b[0;32m~/Code/graph_weather_dev/graph_weather/data/gencast_dataloader.py:142\u001b[0m, in \u001b[0;36mGenCastDataset.__getitem__\u001b[0;34m(self, item)\u001b[0m\n\u001b[1;32m 135\u001b[0m ds_target \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdata\u001b[38;5;241m.\u001b[39misel(time\u001b[38;5;241m=\u001b[39mitem \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m2\u001b[39m \u001b[38;5;241m*\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mtime_step)\n\u001b[1;32m 137\u001b[0m \u001b[38;5;66;03m# Load inputs data\u001b[39;00m\n\u001b[1;32m 138\u001b[0m ds_inputs_atm \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 139\u001b[0m \u001b[43mds_inputs\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43matmospheric_features\u001b[49m\u001b[43m]\u001b[49m\n\u001b[1;32m 140\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mto_array\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 141\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtranspose\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mtime\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlongitude\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlatitude\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mlevel\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mvariable\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m)\u001b[49m\n\u001b[0;32m--> 142\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\n\u001b[1;32m 143\u001b[0m )\n\u001b[1;32m 144\u001b[0m ds_inputs_atm \u001b[38;5;241m=\u001b[39m einops\u001b[38;5;241m.\u001b[39mrearrange(ds_inputs_atm, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mt lon lat lev var -> t lon lat (var lev)\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 145\u001b[0m ds_inputs_single \u001b[38;5;241m=\u001b[39m (\n\u001b[1;32m 146\u001b[0m ds_inputs[\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39msingle_features]\n\u001b[1;32m 147\u001b[0m \u001b[38;5;241m.\u001b[39mto_array()\n\u001b[1;32m 148\u001b[0m \u001b[38;5;241m.\u001b[39mtranspose(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtime\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlongitude\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlatitude\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mvariable\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[1;32m 149\u001b[0m \u001b[38;5;241m.\u001b[39mvalues\n\u001b[1;32m 150\u001b[0m )\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/xarray/core/dataarray.py:733\u001b[0m, in \u001b[0;36mDataArray.values\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 724\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m 725\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mvalues\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m np\u001b[38;5;241m.\u001b[39mndarray:\n\u001b[1;32m 726\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[1;32m 727\u001b[0m \u001b[38;5;124;03m The array's data as a numpy.ndarray.\u001b[39;00m\n\u001b[1;32m 728\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 731\u001b[0m \u001b[38;5;124;03m type does not support coercion like this (e.g. cupy).\u001b[39;00m\n\u001b[1;32m 732\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 733\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mvalues\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/xarray/core/variable.py:614\u001b[0m, in \u001b[0;36mVariable.values\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 611\u001b[0m \u001b[38;5;129m@property\u001b[39m\n\u001b[1;32m 612\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mvalues\u001b[39m(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 613\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"The variable's data as a numpy.ndarray\"\"\"\u001b[39;00m\n\u001b[0;32m--> 614\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43m_as_array_or_item\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_data\u001b[49m\u001b[43m)\u001b[49m\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/xarray/core/variable.py:314\u001b[0m, in \u001b[0;36m_as_array_or_item\u001b[0;34m(data)\u001b[0m\n\u001b[1;32m 300\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m_as_array_or_item\u001b[39m(data):\n\u001b[1;32m 301\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Return the given values as a numpy array, or as an individual item if\u001b[39;00m\n\u001b[1;32m 302\u001b[0m \u001b[38;5;124;03m it's a 0d datetime64 or timedelta64 array.\u001b[39;00m\n\u001b[1;32m 303\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 312\u001b[0m \u001b[38;5;124;03m TODO: remove this (replace with np.asarray) once these issues are fixed\u001b[39;00m\n\u001b[1;32m 313\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 314\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[43mnp\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43masarray\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 315\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m data\u001b[38;5;241m.\u001b[39mndim \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 316\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m data\u001b[38;5;241m.\u001b[39mdtype\u001b[38;5;241m.\u001b[39mkind \u001b[38;5;241m==\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mM\u001b[39m\u001b[38;5;124m\"\u001b[39m:\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/dask/array/core.py:1693\u001b[0m, in \u001b[0;36mArray.__array__\u001b[0;34m(self, dtype, **kwargs)\u001b[0m\n\u001b[1;32m 1692\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__array__\u001b[39m(\u001b[38;5;28mself\u001b[39m, dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[0;32m-> 1693\u001b[0m x \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1694\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m dtype \u001b[38;5;129;01mand\u001b[39;00m x\u001b[38;5;241m.\u001b[39mdtype \u001b[38;5;241m!=\u001b[39m dtype:\n\u001b[1;32m 1695\u001b[0m x \u001b[38;5;241m=\u001b[39m x\u001b[38;5;241m.\u001b[39mastype(dtype)\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/dask/base.py:375\u001b[0m, in \u001b[0;36mDaskMethodsMixin.compute\u001b[0;34m(self, **kwargs)\u001b[0m\n\u001b[1;32m 351\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute\u001b[39m(\u001b[38;5;28mself\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs):\n\u001b[1;32m 352\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"Compute this dask collection\u001b[39;00m\n\u001b[1;32m 353\u001b[0m \n\u001b[1;32m 354\u001b[0m \u001b[38;5;124;03m This turns a lazy Dask collection into its in-memory equivalent.\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 373\u001b[0m \u001b[38;5;124;03m dask.compute\u001b[39;00m\n\u001b[1;32m 374\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[0;32m--> 375\u001b[0m (result,) \u001b[38;5;241m=\u001b[39m \u001b[43mcompute\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mtraverse\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mFalse\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 376\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m result\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/site-packages/dask/base.py:661\u001b[0m, in \u001b[0;36mcompute\u001b[0;34m(traverse, optimize_graph, scheduler, get, *args, **kwargs)\u001b[0m\n\u001b[1;32m 658\u001b[0m postcomputes\u001b[38;5;241m.\u001b[39mappend(x\u001b[38;5;241m.\u001b[39m__dask_postcompute__())\n\u001b[1;32m 660\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m shorten_traceback():\n\u001b[0;32m--> 661\u001b[0m results \u001b[38;5;241m=\u001b[39m \u001b[43mschedule\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdsk\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkeys\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 663\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m repack([f(r, \u001b[38;5;241m*\u001b[39ma) \u001b[38;5;28;01mfor\u001b[39;00m r, (f, a) \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(results, postcomputes)])\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/queue.py:171\u001b[0m, in \u001b[0;36mQueue.get\u001b[0;34m(self, block, timeout)\u001b[0m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[1;32m 170\u001b[0m \u001b[38;5;28;01mwhile\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_qsize():\n\u001b[0;32m--> 171\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnot_empty\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwait\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 172\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m timeout \u001b[38;5;241m<\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[1;32m 173\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mValueError\u001b[39;00m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtimeout\u001b[39m\u001b[38;5;124m'\u001b[39m\u001b[38;5;124m must be a non-negative number\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n", - "File \u001b[0;32m~/.conda/envs/graph/lib/python3.11/threading.py:327\u001b[0m, in \u001b[0;36mCondition.wait\u001b[0;34m(self, timeout)\u001b[0m\n\u001b[1;32m 325\u001b[0m \u001b[38;5;28;01mtry\u001b[39;00m: \u001b[38;5;66;03m# restore state no matter what (e.g., KeyboardInterrupt)\u001b[39;00m\n\u001b[1;32m 326\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m timeout \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m--> 327\u001b[0m \u001b[43mwaiter\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43macquire\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 328\u001b[0m gotit \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mTrue\u001b[39;00m\n\u001b[1;32m 329\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n", - "\u001b[0;31mKeyboardInterrupt\u001b[0m: " + "name": "stdout", + "output_type": "stream", + "text": [ + "Iteration 0 with training loss 0.1087784692645073\n", + "Iteration 1 with training loss 0.10737644881010056\n", + "Iteration 2 with training loss 0.1024656742811203\n", + "Iteration 3 with training loss 0.10202587395906448\n", + "Iteration 4 with training loss 0.09606097638607025\n", + "Iteration 5 with training loss 0.09462107717990875\n" ] } ], @@ -584,328 +585,37 @@ " device=torch.device(\"cpu\"),\n", ")\n", "\n", - "loss = WeightedMSELoss(\n", + "criterion = WeightedMSELoss(\n", " grid_lat=torch.tensor(dataset.grid_lat),\n", " pressure_levels=torch.tensor(dataset.pressure_levels),\n", " num_atmospheric_features=len(dataset.atmospheric_features),\n", " single_features_weights=torch.tensor([1.0, 0.1, 0.1, 0.1, 0.1]),\n", ")\n", "\n", - "corrupted_residuals, prev_inputs, noise_levels, target_residuals = next(iter(dataloader))\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "\n", - "from graph_weather.models.gencast.layers.modules import MLP, InteractionNetwork\n", - "\n", - "nn = MLP(\n", - " input_dim=5, hidden_dims=[10, 5, 4, 2], activation_layer=torch.nn.GELU, use_layer_norm=True\n", - ")\n", - "\n", - "edge_index = torch.tensor(\n", - " [[0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4], [0, 1, 2, 0, 1, 2, 0, 1, 1, 2, 0]], dtype=torch.long\n", - ")\n", - "\n", - "sender_nodes = 5\n", - "receiver_nodes = 3\n", - "sender_dim = 10\n", - "receiver_dim = 11\n", - "edge_attr_dim = 12\n", - "hidden_dims = [13, 14]\n", - "\n", - "sender_features = torch.rand((sender_nodes, sender_dim))\n", - "receiver_features = torch.rand((receiver_nodes, receiver_dim))\n", - "edge_attr = torch.rand((len(edge_index[0]), edge_attr_dim))\n", - "gnn = InteractionNetwork(\n", - " sender_dim=sender_dim,\n", - " receiver_dim=receiver_dim,\n", - " edge_attr_dim=edge_attr_dim,\n", - " hidden_dims=hidden_dims,\n", - ")\n", - "\n", - "assert gnn((sender_features, receiver_features), edge_index, edge_attr).shape == (receiver_nodes,\n", - " hidden_dims[-1])" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(tensor([[-0.4831, -0.8099, 0.9474, -0.3155, 1.3251, -3.0360, 2.3721],\n", - " [-0.8147, -0.4318, 0.9071, -0.1851, 1.5138, -3.1564, 2.1669],\n", - " [-0.6758, -0.5384, 0.9154, -0.3039, 1.4295, -3.0700, 2.2433],\n", - " [-0.7029, -0.4444, 0.9977, 0.0664, 1.4838, -3.3822, 1.9816],\n", - " [-0.2588, -0.9820, 0.9762, -0.3303, 1.1857, -2.9864, 2.3956]],\n", - " grad_fn=),\n", - " tensor([[ 1.0823, -0.1741, -0.1074, 0.1718, -2.1206, 1.2814, -0.1335],\n", - " [ 1.4093, -0.1385, 0.4016, -0.0235, -1.7778, -0.0566, 0.1855],\n", - " [ 0.2196, -0.2307, -0.0837, 0.8277, -1.7540, 1.4246, -0.4036]],\n", - " grad_fn=))" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "from graph_weather.models.gencast.layers.encoder import Encoder\n", - "grid_dim=3\n", - "mesh_dim=4\n", - "edge_dim=5\n", - "hidden_dims=[6,7]\n", - "encoder = Encoder(grid_dim=grid_dim, mesh_dim=mesh_dim, edge_dim=edge_dim, hidden_dims=hidden_dims, activation_layer=torch.nn.SiLU, use_layer_norm=True)\n", - "grid_features = torch.rand((sender_nodes, grid_dim))\n", - "mesh_features = torch.rand((receiver_nodes, mesh_dim))\n", - "edge_features = torch.rand((len(edge_index[0]), edge_dim))\n", - "encoder.forward(grid_features, mesh_features, edge_features, edge_index)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[1, 2, 3, 4]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "[1,2,3]+[4]" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "/home/gbruno16/Code/graph_weather_dev\n" - ] - } - ], - "source": [ - "%cd ../" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "from torch_geometric.data import Batch\n", - "import numpy as np\n", - "\n", - "from graph_weather.models.gencast import GraphBuilder\n", - "\n", - "grid_lat = np.arange(-90, 90, 90)\n", - "grid_lon = np.arange(0, 360, 180)\n", + "optimizer = torch.optim.AdamW(denoiser.parameters(), lr=1e-3)\n", "\n", - "graphs = GraphBuilder(grid_lat=grid_lat, grid_lon=grid_lon, splits=0, num_hops=0)\n", + "for i, data in enumerate(dataloader):\n", + " corrupted_targets, prev_inputs, noise_levels, target_residuals = data\n", + " denoiser.zero_grad()\n", + " preds = denoiser(\n", + " corrupted_targets=corrupted_targets, prev_inputs=prev_inputs, noise_levels=noise_levels\n", + " )\n", + " loss = criterion(preds, noise_levels, target_residuals)\n", + " loss.backward()\n", "\n", - "g = graphs.mesh_graph\n" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "HeteroDataBatch(\n", - " grid_nodes={\n", - " x=[12, 3],\n", - " batch=[12],\n", - " ptr=[4],\n", - " },\n", - " mesh_nodes={\n", - " x=[36, 3],\n", - " batch=[36],\n", - " ptr=[4],\n", - " },\n", - " (grid_nodes, to, mesh_nodes)={\n", - " edge_index=[2, 6],\n", - " edge_attr=[6, 4],\n", - " }\n", - ")\n", - "tensor([[ 2, 3, 6, 7, 10, 11],\n", - " [ 8, 5, 20, 17, 32, 29]])\n" - ] - } - ], - "source": [ - "gb = Batch.from_data_list([graphs.g2m_graph]*3)\n", - "\n", - "print(gb)\n", - "\n", - "print(gb[\"grid_nodes\",\"to\",\"mesh_nodes\"].edge_index)" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'x': tensor([[ 1.8759e-01, 5.0000e-01, 8.6603e-01],\n", - " [ 7.9465e-01, -5.0000e-01, 8.6603e-01],\n", - " [ 7.9465e-01, 1.0000e+00, 0.0000e+00],\n", - " [ 1.8759e-01, 5.0000e-01, -8.6603e-01],\n", - " [-7.9465e-01, 5.0000e-01, 8.6603e-01],\n", - " [ 1.8759e-01, -1.0000e+00, -8.7423e-08],\n", - " [-1.8759e-01, -5.0000e-01, 8.6603e-01],\n", - " [ 7.9465e-01, -5.0000e-01, -8.6603e-01],\n", - " [-1.8759e-01, 1.0000e+00, 0.0000e+00],\n", - " [-1.8759e-01, -5.0000e-01, -8.6603e-01],\n", - " [-7.9465e-01, 5.0000e-01, -8.6603e-01],\n", - " [-7.9465e-01, -1.0000e+00, -8.7423e-08]])}" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "graphs.m2g_graph[\"mesh_nodes\"]" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import torch\n", - "from graph_weather.models.gencast import Denoiser\n", - "\n", - "grid_lat = np.arange(-90, 90, 90)\n", - "grid_lon = np.arange(0, 360, 180)\n", - "batch_size = 9\n", - "input_features_dim=10\n", - "output_features_dim=5\n", - "\n", - "denoiser = Denoiser(grid_lon=grid_lon,\n", - " grid_lat=grid_lat,\n", - " input_features_dim=10,\n", - " output_features_dim=5,\n", - " hidden_dims=[4,8],\n", - " splits=0,\n", - " num_hops=0)\n" + " optimizer.step()\n", + " print(f\"Iteration {i} with training loss {float(loss)}.\")\n", + " if i==5:\n", + " break\n", + " " ] }, { "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "torch.Size([9, 2, 2, 5])\n" - ] - } - ], - "source": [ - "grid_features = torch.rand((batch_size, len(grid_lon), len(grid_lat),2*input_features_dim))\n", - "print(denoiser.f_theta(grid_features, None).shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "torch.Size([9, 2, 2, 5])" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "import einops\n", - "einops.rearrange(\n", - " torch.rand([9,4,5]), \"b (lon lat) f -> b lon lat f\", lon=grid_features.shape[1]\n", - " ).shape" - ] - }, - { - "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": {}, "outputs": [], - "source": [ - "from torch_geometric.data import Batch\n", - "g2m_batched = Batch([graphs.g2m_graph] * 2)" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "HeteroDataBatch(\n", - " grid_nodes={\n", - " x=[8, 3],\n", - " batch=[8],\n", - " ptr=[3],\n", - " },\n", - " mesh_nodes={\n", - " x=[24, 3],\n", - " batch=[24],\n", - " ptr=[3],\n", - " },\n", - " (grid_nodes, to, mesh_nodes)={\n", - " edge_index=[2, 4],\n", - " edge_attr=[4, 4],\n", - " }\n", - ")" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "Batch.from_data_list([graphs.g2m_graph] * 2)" - ] + "source": [] } ], "metadata": { From d3650d74a534f9df499f8a1970a8072ae907e605 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 26 Jun 2024 08:44:35 +0000 Subject: [PATCH 6/7] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_model.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/tests/test_model.py b/tests/test_model.py index d094fba7..1302f9d2 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -325,9 +325,10 @@ def test_gencast_loss(): preds = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) noise_levels = torch.rand((batch_size, 1)) - targets = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) + targets = torch.rand((batch_size, len(grid_lon), len(grid_lat), features_dim)) assert loss.forward(preds, noise_levels, targets) is not None + def test_gencast_denoiser(): grid_lat = np.arange(-90, 90, 1) grid_lon = np.arange(0, 360, 1) @@ -349,12 +350,12 @@ def test_gencast_denoiser(): ).eval() corrupted_targets = torch.randn((batch_size, len(grid_lon), len(grid_lat), output_features_dim)) - prev_inputs = torch.randn((batch_size, len(grid_lon), len(grid_lat), 2*input_features_dim)) + prev_inputs = torch.randn((batch_size, len(grid_lon), len(grid_lat), 2 * input_features_dim)) noise_levels = torch.randn((batch_size, 1)) with torch.no_grad(): - preds = denoiser(corrupted_targets=corrupted_targets, - prev_inputs=prev_inputs, - noise_levels=noise_levels) - - assert not torch.isnan(preds).any() \ No newline at end of file + preds = denoiser( + corrupted_targets=corrupted_targets, prev_inputs=prev_inputs, noise_levels=noise_levels + ) + + assert not torch.isnan(preds).any() From 3fa681fe99c5d2ca03b0424f828abd4f85567047 Mon Sep 17 00:00:00 2001 From: gbruno16 Date: Thu, 27 Jun 2024 11:32:46 +0200 Subject: [PATCH 7/7] Add denoiser description --- graph_weather/models/gencast/denoiser.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/graph_weather/models/gencast/denoiser.py b/graph_weather/models/gencast/denoiser.py index 482d70b7..a70700b9 100644 --- a/graph_weather/models/gencast/denoiser.py +++ b/graph_weather/models/gencast/denoiser.py @@ -1,4 +1,11 @@ -"""Denoiser.""" +"""Denoiser. + +The denoiser takes as inputs the previous two timesteps, the corrupted target residual, and the +noise level, and outputs the denoised predictions. It performs the following tasks: +1. Initializes the graph, encoder, processor, and decoder. +2. Computes f_theta as the combination of encoder, processor, and decoder. +3. Preconditions f_theta on the noise levels using the parametrization from Karras et al. (2022). +""" import einops import numpy as np