Skip to content

Commit

Permalink
Add loss test and exceptions
Browse files Browse the repository at this point in the history
  • Loading branch information
gbruno16 committed Jun 9, 2024
1 parent 34600fe commit 6a944ad
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 24 deletions.
4 changes: 4 additions & 0 deletions graph_weather/models/gencast/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
"""Main import for GenCast"""

from .graph.graph_builder import GraphBuilder
from .weighted_mse_loss import WeightedMSELoss
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ def __init__(
):
"""Initialize the WeightedMSELoss Module.
More details about the features weights are reported in GraphCast's paper.
In short, if the single features are "2m_temperature", "10m_u_component_of_wind",
"10m_v_component_of_wind", "mean_sea_level_pressure" and "total_precipitation_12hr",
then it's suggested to set corresponding weights as 1, 0.1, 0.1, 0.1 and 0.1.
More details about the features weights are reported in GraphCast's paper. In short, if the
single features are "2m_temperature", "10m_u_component_of_wind", "10m_v_component_of_wind",
"mean_sea_level_pressure" and "total_precipitation_12hr", then it's suggested to set
corresponding weights as 1, 0.1, 0.1, 0.1 and 0.1.
Args:
grid_lat (torch.Tensor, optional): 1D tensor containing all the latitudes.
pressure_levels (torch.Tensor, optional): 1D tensor containing all the pressure
levels per variable.
pressure_levels (torch.Tensor, optional): 1D tensor containing all the pressure levels
per variable.
num_atmospheric_features (int, optional): number of atmospheric features.
single_features_weights (torch.Tensor, optional): 1D tensor containing single
features weights.
single_features_weights (torch.Tensor, optional): 1D tensor containing single features
weights.
"""
super().__init__()

Expand All @@ -57,8 +57,8 @@ def __init__(
or single_features_weights is not None
):
raise ValueError(
"""Please to use features weights provide all three: pressure_levels,
num_atmospheric_features and single_features_weights."""
"Please to use features weights provide all three: pressure_levels,"
"num_atmospheric_features and single_features_weights."
)

self.sigma_data = 1 # assuming normalized data!
Expand All @@ -75,28 +75,55 @@ def forward(
Args:
pred (torch.Tensor): prediction of the model [batch, lon, lat, var].
target (torch.Tensor): target tensor [batch, lon, lat, var].
noise_level (torch.Tensor): noise levels fed to the model for the
corresponding predictions [batch, 1]
noise_level (torch.Tensor): noise levels fed to the model for the corresponding
predictions [batch, 1].
Returns:
torch.Tensor: weighted MSE loss.
"""
# check shapes
if not (pred.shape == target.shape):
raise ValueError(
"redictions and targets must have same shape. The actual shapes "
f"are {pred.shape} and {target.shape}."
)
if not (len(pred.shape) == 4):
raise ValueError(
"The expected shape for predictions and targets is "
f"[batch, lon, lat, var], but got {pred.shape}."
)
if not (noise_level.shape == (pred.shape[0], 1)):
raise ValueError(
f"The expected shape for noise levels is [batch, 1], but got {noise_level.shape}."
)

# compute square residuals
loss = (pred - target) ** 2 # [batch, lon, lat, var]
if torch.isnan(loss).any():
raise ValueError("NaN values encountered in loss calculation.")

# apply weight residuals
# apply area and features weights to residuals
if self.area_weights is not None:
if not (len(self.area_weights) == pred.shape[2]):
raise ValueError(
f"The size of grid_lat at initialization ({len(self.area_weights)}) "
f"and the number of latitudes in predictions ({pred.shape[2]}) "
"don't match."
)
loss *= self.area_weights[None, None, :, None]

if self.features_weights is not None:
loss *= self.feature_weights[None, None, None, :]

# compute mean across lon, lat, var for each sample in the batch
if not (len(self.features_weights) == pred.shape[-1]):
raise ValueError(
f"The size of features weights at initialization ({len(self.features_weights)})"
f" and the number of features in predictions ({pred.shape[-1]}) "
"don't match."
)
loss *= self.features_weights[None, None, None, :]

# compute means across lon, lat, var for each sample in the batch
loss = loss.flatten(1).mean(-1) # [batch]

# weight each sample using the corresponding noise level, then return the mean.
loss = (self._lambda_sigma(noise_level) * loss[:, None]).mean()

return loss
loss *= self._lambda_sigma(noise_level).flatten()
return loss.mean()
26 changes: 25 additions & 1 deletion tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
MetaModel,
)
from graph_weather.models.losses import NormalizedMSELoss

from graph_weather.models.gencast.utils.noise import (
generate_isotropic_noise,
sample_noise_level,
)
from graph_weather.models.gencast.graph.graph_builder import GraphBuilder
from graph_weather.models.gencast import GraphBuilder, WeightedMSELoss

def test_encoder():
lat_lons = []
Expand Down Expand Up @@ -266,3 +267,26 @@ def test_gencast_graph():
assert not torch.isnan(graphs.mesh_graph.edge_attr).any()
assert graphs.khop_mesh_graph.x.shape[0] == 12
assert graphs.khop_mesh_graph.edge_attr.shape[0] == 12*10

def test_gencast_loss():
    """Check that WeightedMSELoss yields a finite scalar on well-formed inputs.

    The loss is built with latitude (area) weights and per-feature weights, then
    evaluated on random predictions/targets shaped [batch, lon, lat, var] with
    noise levels shaped [batch, 1].
    """
    grid_lat = torch.arange(-90, 90, 1)
    grid_lon = torch.arange(0, 360, 1)
    pressure_levels = torch.tensor(
        [50.0, 100.0, 150.0, 200.0, 250, 300, 400, 500, 600, 700, 850, 925, 1000.0]
    )
    single_features_weights = torch.tensor([1, 0.1, 0.1, 0.1, 0.1])
    num_atmospheric_features = 6
    batch_size = 3
    # One channel per (atmospheric feature, pressure level) pair plus the single-level ones.
    features_dim = len(pressure_levels) * num_atmospheric_features + len(
        single_features_weights
    )

    loss_fn = WeightedMSELoss(
        grid_lat=grid_lat,
        pressure_levels=pressure_levels,
        num_atmospheric_features=num_atmospheric_features,
        single_features_weights=single_features_weights,
    )

    sample_shape = (batch_size, len(grid_lon), len(grid_lat), features_dim)
    preds = torch.rand(sample_shape)
    noise_levels = torch.rand((batch_size, 1))
    targets = torch.rand(sample_shape)

    # Invoke the module itself rather than .forward so registered hooks also run.
    result = loss_fn(preds, targets, noise_levels)

    # "is not None" can never fail for a returned tensor; pin the actual contract:
    # a single, finite scalar value.
    assert isinstance(result, torch.Tensor)
    assert result.ndim == 0
    assert torch.isfinite(result)
8 changes: 4 additions & 4 deletions train/gencast_demo.ipynb

Large diffs are not rendered by default.

0 comments on commit 6a944ad

Please sign in to comment.