Skip to content

Commit

Permalink
Work on fixing plotting functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
joeloskarsson committed Nov 12, 2024
1 parent 22bfe65 commit c7eddb2
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 14 deletions.
2 changes: 1 addition & 1 deletion neural_lam/data_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ dataset:
- z_isobaricInhPa_1000_instant
- z_isobaricInhPa_500_instant
num_forcing_features: 16
grid_shape_state: [268, 238]
grid_shape_state: [248, 218]
projection:
class: LambertConformal
kwargs:
Expand Down
10 changes: 8 additions & 2 deletions neural_lam/models/ar_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,7 @@ def plot_examples(self, batch, n_examples, prediction=None):
title=f"{var_name} ({var_unit}), "
f"t={t_i} ({self.step_length * t_i} h)",
vrange=var_vrange,
grid_limits=self.grid_limits
)
for var_i, (var_name, var_unit, var_vrange) in enumerate(
zip(
Expand Down Expand Up @@ -442,7 +443,8 @@ def create_metric_log_dict(self, metric_tensor, prefix, metric_name):
"""
log_dict = {}
metric_fig = vis.plot_error_map(
metric_tensor, self.config_loader, step_length=self.step_length
metric_tensor, self.config_loader, step_length=self.step_length,
grid_limits=self.grid_limits
)
full_log_name = f"{prefix}_{metric_name}"
log_dict[full_log_name] = wandb.Image(metric_fig)
Expand Down Expand Up @@ -532,6 +534,8 @@ def on_test_epoch_end(self):
loss_map,
self.config_loader,
title=f"Test loss, t={t_i} ({self.step_length * t_i} h)",

grid_limits=self.grid_limits
)
for t_i, loss_map in zip(
self.args.val_steps_to_log, mean_spatial_loss
Expand All @@ -544,7 +548,9 @@ def on_test_epoch_end(self):

# also make without title and save as pdf
pdf_loss_map_figs = [
vis.plot_spatial_error(loss_map, self.config_loader)
vis.plot_spatial_error(loss_map, self.config_loader,
grid_limits=self.grid_limits
)
for loss_map in mean_spatial_loss
]
pdf_loss_maps_dir = os.path.join(wandb.run.dir, "spatial_loss_maps")
Expand Down
10 changes: 10 additions & 0 deletions neural_lam/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,15 @@ def loads_file(fn):
device=device,
) # (d_f,)

raw_coords = np.load(os.path.join(static_dir_path, "nwp_xy.npy"))
interior_coords = raw_coords.reshape(2, -1)[:,interior_mask.numpy()]
grid_limits = [
interior_coords[0].min(),
interior_coords[0].max(),
interior_coords[1].min(),
interior_coords[1].max(),
]

return {
"boundary_mask": boundary_mask,
"interior_mask": interior_mask,
Expand All @@ -88,6 +97,7 @@ def loads_file(fn):
"data_mean": data_mean,
"data_std": data_std,
"param_weights": param_weights,
"grid_limits": grid_limits,
}


Expand Down
15 changes: 8 additions & 7 deletions neural_lam/vis.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ def plot_on_axis(
vmax=None,
ax_title=None,
cmap="plasma",
grid_limits=None
):
"""
Plot weather state on given axis
Expand All @@ -84,17 +85,17 @@ def plot_on_axis(
mask_reshaped.clamp(0.7, 1).cpu().numpy()
) # Faded border region

ax.set_global()
ax.coastlines() # Add coastline outlines
data_grid = data.reshape(*data_config.grid_shape_state).cpu().numpy().T
data_grid = data.reshape(*data_config.grid_shape_state).cpu().numpy()
im = ax.imshow(
data_grid,
origin="lower",
alpha=pixel_alpha,
vmin=vmin,
vmax=vmax,
cmap=cmap,
) # TODO Do we not need extent and transform arguments here?
extent=grid_limits
)

if ax_title:
ax.set_title(ax_title, size=15)
Expand All @@ -103,7 +104,7 @@ def plot_on_axis(

@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_prediction(
pred, target, data_config, obs_mask=None, title=None, vrange=None
pred, target, data_config, obs_mask=None, title=None, vrange=None, grid_limits=None
):
"""
Plot example prediction and grond truth.
Expand All @@ -125,7 +126,7 @@ def plot_prediction(

# Plot pred and target
for ax, data in zip(axes, (target, pred)):
im = plot_on_axis(ax, data, data_config, obs_mask, vmin, vmax)
im = plot_on_axis(ax, data, data_config, obs_mask, vmin, vmax, grid_limits=grid_limits)

# Ticks and labels
axes[0].set_title("Ground Truth", size=15)
Expand All @@ -141,7 +142,7 @@ def plot_prediction(

@matplotlib.rc_context(utils.fractional_plot_bundle(1))
def plot_spatial_error(
error, data_config, obs_mask=None, title=None, vrange=None
error, data_config, obs_mask=None, title=None, vrange=None, grid_limits=None
):
"""
Plot errors over spatial map
Expand All @@ -159,7 +160,7 @@ def plot_spatial_error(
subplot_kw={"projection": data_config.coords_projection},
)

im = plot_on_axis(ax, error, data_config, obs_mask, vmin, vmax, cmap="OrRd")
im = plot_on_axis(ax, error, data_config, obs_mask, vmin, vmax, cmap="OrRd", grid_limits=grid_limits)

# Ticks and labels
cbar = fig.colorbar(im, aspect=30)
Expand Down
25 changes: 21 additions & 4 deletions tests/test_mllam_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
# Third-party
import pooch
import pytest
import matplotlib.pyplot as plt

# First-party
from neural_lam.build_graph import main as build_graph
from neural_lam.config import Config
from neural_lam.train_model import main as train_model
from neural_lam.utils import load_static_data
from neural_lam.weather_dataset import WeatherDataset
from neural_lam.vis import plot_prediction

# Disable weights and biases to avoid unnecessary logging
# and to avoid having to deal with authentication
Expand Down Expand Up @@ -67,7 +69,9 @@ def test_load_reduced_meps_dataset(meps_example_reduced_filepath):
n_prediction_timesteps = dataset.sample_length - n_input_steps

static_data = load_static_data(dataset_name)
n_grid = static_data["interior_mask"].sum().item()
nx, ny = config.values["grid_shape_state"]
n_grid = nx * ny
static_data["interior_mask"].sum().item()
n_boundary = static_data["boundary_mask"].sum().item()

# check that the dataset is not empty
Expand Down Expand Up @@ -107,13 +111,14 @@ def test_load_reduced_meps_dataset(meps_example_reduced_filepath):
}

# check the sizes of the props
# TODO Should this config not be for only interior?
nx, ny = config.values["grid_shape_state"]
assert n_grid + n_boundary == nx * ny
assert static_data["grid_static_features"].shape == (
n_grid,
n_grid_static_features,
)
assert static_data["boundary_static_features"].shape == (
n_boundary,
n_grid_static_features, # TODO Adjust dimensionality
)
assert static_data["step_diff_mean"].shape == (n_state_features,)
assert static_data["step_diff_std"].shape == (n_state_features,)
assert static_data["data_mean"].shape == (n_state_features,)
Expand Down Expand Up @@ -150,3 +155,15 @@ def test_train_model_reduced_meps_dataset():
"--n_example_pred=0",
]
train_model(args)


def test_vis_reduced_meps_dataset(meps_example_reduced_filepath):
data_config_file = meps_example_reduced_filepath / "data_config.yaml"
dataset_name = meps_example_reduced_filepath.name

config = Config.from_file(str(data_config_file))

static_data = load_static_data(dataset_name)
geopotential = static_data["grid_static_features"][..., 2]

plot_prediction(geopotential, geopotential, config, grid_limits=static_data["grid_limits"])

0 comments on commit c7eddb2

Please sign in to comment.