From c7eddb2fdcbe4678389f7a5f2e94bfe7feb54366 Mon Sep 17 00:00:00 2001 From: joeloskarsson Date: Tue, 12 Nov 2024 16:06:28 +0100 Subject: [PATCH] Work on fixing plotting functionality --- neural_lam/data_config.yaml | 2 +- neural_lam/models/ar_model.py | 10 ++++++++-- neural_lam/utils.py | 10 ++++++++++ neural_lam/vis.py | 15 ++++++++------- tests/test_mllam_dataset.py | 25 +++++++++++++++++++++---- 5 files changed, 48 insertions(+), 14 deletions(-) diff --git a/neural_lam/data_config.yaml b/neural_lam/data_config.yaml index f1527849..81c8a811 100644 --- a/neural_lam/data_config.yaml +++ b/neural_lam/data_config.yaml @@ -55,7 +55,7 @@ dataset: - z_isobaricInhPa_1000_instant - z_isobaricInhPa_500_instant num_forcing_features: 16 -grid_shape_state: [268, 238] +grid_shape_state: [248, 218] projection: class: LambertConformal kwargs: diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index a255e95b..25218e6f 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -391,6 +391,7 @@ def plot_examples(self, batch, n_examples, prediction=None): title=f"{var_name} ({var_unit}), " f"t={t_i} ({self.step_length * t_i} h)", vrange=var_vrange, + grid_limits=self.grid_limits ) for var_i, (var_name, var_unit, var_vrange) in enumerate( zip( @@ -442,7 +443,8 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name): """ log_dict = {} metric_fig = vis.plot_error_map( - metric_tensor, self.config_loader, step_length=self.step_length + metric_tensor, self.config_loader, step_length=self.step_length, + grid_limits=self.grid_limits ) full_log_name = f"{prefix}_{metric_name}" log_dict[full_log_name] = wandb.Image(metric_fig) @@ -532,6 +534,8 @@ def on_test_epoch_end(self): loss_map, self.config_loader, title=f"Test loss, t={t_i} ({self.step_length * t_i} h)", + + grid_limits=self.grid_limits ) for t_i, loss_map in zip( self.args.val_steps_to_log, mean_spatial_loss @@ -544,7 +548,9 @@ def on_test_epoch_end(self): # also make without title and save as pdf pdf_loss_map_figs = [ - vis.plot_spatial_error(loss_map, self.config_loader) + vis.plot_spatial_error(loss_map, self.config_loader, + grid_limits=self.grid_limits + ) for loss_map in mean_spatial_loss ] pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps") diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 6f90c0a5..064d149b 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -78,6 +78,15 @@ def loads_file(fn): device=device, ) # (d_f,) + raw_coords = np.load(os.path.join(static_dir_path, "nwp_xy.npy")) + interior_coords = raw_coords.reshape(2, -1)[:,interior_mask.numpy()] + grid_limits = [ + interior_coords[0].min(), + interior_coords[0].max(), + interior_coords[1].min(), + interior_coords[1].max(), + ] + return { "boundary_mask": boundary_mask, "interior_mask": interior_mask, @@ -88,6 +97,7 @@ def loads_file(fn): "data_mean": data_mean, "data_std": data_std, "param_weights": param_weights, + "grid_limits": grid_limits, } diff --git a/neural_lam/vis.py b/neural_lam/vis.py index 5a43835d..fc2ddaae 100644 --- a/neural_lam/vis.py +++ b/neural_lam/vis.py @@ -71,6 +71,7 @@ def plot_on_axis( vmax=None, ax_title=None, cmap="plasma", + grid_limits=None ): """ Plot weather state on given axis @@ -84,9 +85,8 @@ def plot_on_axis( mask_reshaped.clamp(0.7, 1).cpu().numpy() ) # Faded border region - ax.set_global() ax.coastlines() # Add coastline outlines - data_grid = data.reshape(*data_config.grid_shape_state).cpu().numpy().T + data_grid = data.reshape(*data_config.grid_shape_state).cpu().numpy() im = ax.imshow( data_grid, origin="lower", @@ -94,7 +94,8 @@ def plot_on_axis( vmin=vmin, vmax=vmax, cmap=cmap, - ) # TODO Do we not need extent and transform arguments here? + extent=grid_limits + ) if ax_title: ax.set_title(ax_title, size=15) @@ -103,7 +104,7 @@ def plot_on_axis( @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_prediction( - pred, target, data_config, obs_mask=None, title=None, vrange=None + pred, target, data_config, obs_mask=None, title=None, vrange=None, grid_limits=None ): """ Plot example prediction and grond truth. @@ -125,7 +126,7 @@ def plot_prediction( # Plot pred and target for ax, data in zip(axes, (target, pred)): - im = plot_on_axis(ax, data, data_config, obs_mask, vmin, vmax) + im = plot_on_axis(ax, data, data_config, obs_mask, vmin, vmax, grid_limits=grid_limits) # Ticks and labels axes[0].set_title("Ground Truth", size=15) @@ -141,7 +142,7 @@ def plot_prediction( @matplotlib.rc_context(utils.fractional_plot_bundle(1)) def plot_spatial_error( - error, data_config, obs_mask=None, title=None, vrange=None + error, data_config, obs_mask=None, title=None, vrange=None, grid_limits=None ): """ Plot errors over spatial map @@ -159,7 +160,7 @@ def plot_spatial_error( subplot_kw={"projection": data_config.coords_projection}, ) - im = plot_on_axis(ax, error, data_config, obs_mask, vmin, vmax, cmap="OrRd") + im = plot_on_axis(ax, error, data_config, obs_mask, vmin, vmax, cmap="OrRd", grid_limits=grid_limits) # Ticks and labels cbar = fig.colorbar(im, aspect=30) diff --git a/tests/test_mllam_dataset.py b/tests/test_mllam_dataset.py index d268e87f..af5ba0b6 100644 --- a/tests/test_mllam_dataset.py +++ b/tests/test_mllam_dataset.py @@ -5,6 +5,7 @@ # Third-party import pooch import pytest +import matplotlib.pyplot as plt # First-party from neural_lam.build_graph import main as build_graph @@ -12,6 +13,7 @@ from neural_lam.train_model import main as train_model from neural_lam.utils import load_static_data from neural_lam.weather_dataset import WeatherDataset +from neural_lam.vis import plot_prediction # Disable weights and biases to avoid unnecessary logging # and to avoid having to deal with authentication @@ -67,7 +69,9 @@ def test_load_reduced_meps_dataset(meps_example_reduced_filepath): n_prediction_timesteps = dataset.sample_length - n_input_steps static_data = load_static_data(dataset_name) - n_grid = static_data["interior_mask"].sum().item() + nx, ny = config.values["grid_shape_state"] + n_grid = nx * ny + static_data["interior_mask"].sum().item() n_boundary = static_data["boundary_mask"].sum().item() # check that the dataset is not empty @@ -107,13 +111,14 @@ def test_load_reduced_meps_dataset(meps_example_reduced_filepath): } # check the sizes of the props - # TODO Should this config not be for only interior? - nx, ny = config.values["grid_shape_state"] - assert n_grid + n_boundary == nx * ny assert static_data["grid_static_features"].shape == ( n_grid, n_grid_static_features, ) + assert static_data["boundary_static_features"].shape == ( + n_boundary, + n_grid_static_features, # TODO Adjust dimensionality + ) assert static_data["step_diff_mean"].shape == (n_state_features,) assert static_data["step_diff_std"].shape == (n_state_features,) assert static_data["data_mean"].shape == (n_state_features,) @@ -150,3 +155,15 @@ def test_train_model_reduced_meps_dataset(): "--n_example_pred=0", ] train_model(args) + + +def test_vis_reduced_meps_dataset(meps_example_reduced_filepath): + data_config_file = meps_example_reduced_filepath / "data_config.yaml" + dataset_name = meps_example_reduced_filepath.name + + config = Config.from_file(str(data_config_file)) + + static_data = load_static_data(dataset_name) + geopotential = static_data["grid_static_features"][..., 2] + + plot_prediction(geopotential, geopotential, config, grid_limits=static_data["grid_limits"])