Skip to content

Commit

Permalink
Test fix related to plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Nov 13, 2024
1 parent c7eddb2 commit aac1ff3
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 16 deletions.
29 changes: 17 additions & 12 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,15 @@ def __init__(self, args):
static_data_dict = utils.load_static_data(
self.config_loader.dataset.name
)
for static_data_name, static_data_tensor in static_data_dict.items():
self.register_buffer(
static_data_name, static_data_tensor, persistent=False
)

for static_data_name, static_data in static_data_dict.items():
if isinstance(static_data, torch.Tensor):
self.register_buffer(
static_data_name, static_data, persistent=False
)
else:
# Non-tensor static can not and should not be buffers
setattr(self, static_data_name, static_data)

# Double grid output dim. to also output std.-dev.
self.output_std = bool(args.output_std)
Expand Down Expand Up @@ -391,7 +396,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
title=f"{var_name} ({var_unit}), "
f"t={t_i} ({self.step_length * t_i} h)",
vrange=var_vrange,
grid_limits=self.grid_limits
grid_limits=self.grid_limits,
)
for var_i, (var_name, var_unit, var_vrange) in enumerate(
zip(
Expand Down Expand Up @@ -443,8 +448,9 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
"""
log_dict = {}
metric_fig = vis.plot_error_map(
metric_tensor, self.config_loader, step_length=self.step_length,
grid_limits=self.grid_limits
metric_tensor,
self.config_loader,
step_length=self.step_length,
)
full_log_name = f"{prefix}_{metric_name}"
log_dict[full_log_name] = wandb.Image(metric_fig)
Expand Down Expand Up @@ -534,8 +540,7 @@ def on_test_epoch_end(self):
loss_map,
self.config_loader,
title=f"Test loss, t={t_i} ({self.step_length * t_i} h)",

grid_limits=self.grid_limits
grid_limits=self.grid_limits,
)
for t_i, loss_map in zip(
self.args.val_steps_to_log, mean_spatial_loss
Expand All @@ -548,9 +553,9 @@ def on_test_epoch_end(self):

# also make without title and save as pdf
pdf_loss_map_figs = [
vis.plot_spatial_error(loss_map, self.config_loader,
grid_limits=self.grid_limits
)
vis.plot_spatial_error(
loss_map, self.config_loader, grid_limits=self.grid_limits
)
for loss_map in mean_spatial_loss
]
pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
Expand Down
2 changes: 1 addition & 1 deletion neural_lam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def loads_file(fn):
) # (d_f,)

raw_coords = np.load(os.path.join(static_dir_path, "nwp_xy.npy"))
interior_coords = raw_coords.reshape(2, -1)[:,interior_mask.numpy()]
interior_coords = raw_coords.reshape(2, -1)[:, interior_mask.numpy()]
grid_limits = [
interior_coords[0].min(),
interior_coords[0].max(),
Expand Down
12 changes: 9 additions & 3 deletions tests/test_mllam_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@
# Third-party
import pooch
import pytest
import matplotlib.pyplot as plt

# First-party
from neural_lam.build_graph import main as build_graph
from neural_lam.config import Config
from neural_lam.train_model import main as train_model
from neural_lam.utils import load_static_data
from neural_lam.weather_dataset import WeatherDataset
from neural_lam.vis import plot_prediction
from neural_lam.weather_dataset import WeatherDataset

# Disable weights and biases to avoid unnecessary logging
# and to avoid having to deal with authentication
Expand Down Expand Up @@ -108,6 +107,7 @@ def test_load_reduced_meps_dataset(meps_example_reduced_filepath):
"data_mean",
"data_std",
"param_weights",
"grid_limits",
}

# check the sizes of the props
Expand All @@ -124,6 +124,7 @@ def test_load_reduced_meps_dataset(meps_example_reduced_filepath):
assert static_data["data_mean"].shape == (n_state_features,)
assert static_data["data_std"].shape == (n_state_features,)
assert static_data["param_weights"].shape == (n_state_features,)
assert len(static_data["grid_limits"]) == 4

assert set(static_data.keys()) == required_props

Expand Down Expand Up @@ -166,4 +167,9 @@ def test_vis_reduced_meps_dataset(meps_example_reduced_filepath):
static_data = load_static_data(dataset_name)
geopotential = static_data["grid_static_features"][..., 2]

plot_prediction(geopotential, geopotential, config, grid_limits=static_data["grid_limits"])
plot_prediction(
geopotential,
geopotential,
config,
grid_limits=static_data["grid_limits"],
)

0 comments on commit aac1ff3

Please sign in to comment.