From 71cfdf918de59b6a40b570508618ce7c9bcda867 Mon Sep 17 00:00:00 2001
From: Leif Denby
Date: Wed, 4 Dec 2024 09:00:10 +0100
Subject: [PATCH 1/2] Fix evaluation example visualisation plots (#91)

Fix bugs in the recently introduced datastore functionality (#66): an error
in the calculation in `BaseDatastore.get_xy_extent()` and an overlooked
in-place modification of the config dict in `MDPDatastore.coords_projection`.
Also fix an issue in `ARModel.plot_examples` by using the newly introduced
(#66) `WeatherDataset.create_dataarray_from_tensor()` to create an
`xr.DataArray` from the prediction tensor and calling plot methods directly
on the `xr.DataArray` rather than using bare numpy arrays with `matplotlib`.
---
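A note on the `BaseDatastore.get_xy_extent()` fix: with `stacked=True` the
coordinates come back as a single array of (x, y) pairs, so the extent must
be read off the columns rather than the leading axis. A minimal numpy sketch
of the corrected indexing, using a made-up four-point grid for illustration:

    import numpy as np

    # Hypothetical stacked coordinate array of shape (n_grid_points, 2), one
    # (x, y) pair per grid point, mimicking get_xy(category, stacked=True).
    xy = np.array([[0.0, 10.0], [1.0, 10.0], [0.0, 11.0], [1.0, 11.0]])

    # Column 0 holds x, column 1 holds y, so the extent is read column-wise.
    extent = [xy[:, 0].min(), xy[:, 0].max(), xy[:, 1].min(), xy[:, 1].max()]
    assert extent == [0.0, 1.0, 10.0, 11.0]  # [x_min, x_max, y_min, y_max]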
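A note on the `MDPDatastore.coords_projection` fix: this is the classic
shared-mutable-state pitfall, where `kwargs.pop("globe", {})` mutated the
kwargs dict held by the config object, so the "globe" entry silently
disappeared for every later call. A self-contained sketch of the failure
mode and the fix (the dict contents here are made up for illustration):

    import copy

    # Hypothetical projection config, standing in for the dict stored on
    # the dataclass that MDPDatastore reads its projection settings from.
    projection_info = {
        "class_name": "LambertConformal",
        "kwargs": {
            "central_longitude": 25.0,
            "globe": {"semimajor_axis": 6367470.0},
        },
    }

    def projection_kwargs_buggy():
        kwargs = projection_info["kwargs"]  # same dict object as the config
        globe_kwargs = kwargs.pop("globe", {})  # mutates the stored config!
        return kwargs, globe_kwargs

    def projection_kwargs_fixed():
        kwargs = copy.deepcopy(projection_info["kwargs"])  # private copy
        globe_kwargs = kwargs.pop("globe", {})  # stored config stays intact
        return kwargs, globe_kwargs

    projection_kwargs_fixed()
    assert "globe" in projection_info["kwargs"]  # still there on later calls

    projection_kwargs_buggy()
    assert "globe" not in projection_info["kwargs"]  # lost for later calls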
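A note on the new `ARModel._create_dataarray_from_tensor` helper: the
round-trip it relies on is that a flat (time, grid_index, feature) tensor
becomes an `xr.DataArray` whose `grid_index` carries an (x, y) MultiIndex,
so that `.unstack("grid_index")` recovers the 2D grid for plotting. A
self-contained sketch of that mechanism, with made-up sizes, coordinates
and feature names (in the real code the coordinates come from the
datastore, and `xr.Coordinates.from_pandas_multiindex` needs a reasonably
recent xarray):

    import numpy as np
    import pandas as pd
    import torch
    import xarray as xr

    # Made-up sizes: 3 lead times, a 4x5 grid flattened to 20 points, and
    # 2 state features; the tensor stands in for one prediction slice.
    n_time, nx, ny = 3, 4, 5
    tensor = torch.rand(n_time, nx * ny, 2)
    times = pd.date_range("2024-01-01", periods=n_time, freq="3h")

    # The (x, y) MultiIndex on grid_index is what makes unstacking possible.
    midx = pd.MultiIndex.from_product(
        [np.arange(nx), np.arange(ny)], names=("x", "y")
    )
    coords = xr.Coordinates.from_pandas_multiindex(midx, "grid_index")

    da = xr.DataArray(
        tensor.cpu().numpy(),  # move to CPU before handing data to xarray
        dims=("time", "grid_index", "state_feature"),
        coords=coords,
    ).assign_coords(time=times, state_feature=["u", "v"])

    da_grid = da.unstack("grid_index")  # dims: (time, state_feature, x, y)
    da_u = da_grid.isel(time=0).sel(state_feature="u")
    da_u.plot.imshow(x="x", origin="lower")  # one variable, one lead time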
 CHANGELOG.md                  |  4 ++
 neural_lam/datastore/base.py  |  9 +++-
 neural_lam/datastore/mdp.py   |  5 ++-
 neural_lam/models/ar_model.py | 81 +++++++++++++++++++++++++++++++----
 neural_lam/vis.py             | 42 ++++++++----------
 neural_lam/weather_dataset.py |  5 ++-
 pyproject.toml                |  2 +-
 7 files changed, 109 insertions(+), 39 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 12cf54f6..01d4cac9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -13,6 +13,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 [\#66](https://github.com/mllam/neural-lam/pull/66)
 @leifdenby @sadamov
 
+### Fixed
+
+- Fix bugs introduced with the datastores functionality relating to visualisation plots [\#91](https://github.com/mllam/neural-lam/pull/91) @leifdenby
+
 ## [v0.2.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.2.0)
 
 ### Added
diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py
index 0317c2e5..b0055e39 100644
--- a/neural_lam/datastore/base.py
+++ b/neural_lam/datastore/base.py
@@ -295,8 +295,13 @@ def get_xy_extent(self, category: str) -> List[float]:
             The extent of the x, y coordinates.
 
         """
-        xy = self.get_xy(category, stacked=False)
-        extent = [xy[0].min(), xy[0].max(), xy[1].min(), xy[1].max()]
+        xy = self.get_xy(category, stacked=True)
+        extent = [
+            xy[:, 0].min(),
+            xy[:, 0].max(),
+            xy[:, 1].min(),
+            xy[:, 1].max(),
+        ]
         return [float(v) for v in extent]
 
     @property
diff --git a/neural_lam/datastore/mdp.py b/neural_lam/datastore/mdp.py
index 10593a82..0d1aac7b 100644
--- a/neural_lam/datastore/mdp.py
+++ b/neural_lam/datastore/mdp.py
@@ -1,4 +1,5 @@
 # Standard library
+import copy
 import warnings
 from functools import cached_property
 from pathlib import Path
@@ -394,7 +395,9 @@ def coords_projection(self) -> ccrs.Projection:
         class_name = projection_info["class_name"]
         ProjectionClass = getattr(ccrs, class_name)
 
-        kwargs = projection_info["kwargs"]
+        # need to copy otherwise we modify the dict stored in the dataclass
+        # in-place
+        kwargs = copy.deepcopy(projection_info["kwargs"])
 
         globe_kwargs = kwargs.pop("globe", {})
         if len(globe_kwargs) > 0:
diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py
index bc4c6719..44baf9c2 100644
--- a/neural_lam/models/ar_model.py
+++ b/neural_lam/models/ar_model.py
@@ -1,5 +1,6 @@
 # Standard library
 import os
+from typing import List, Union
 
 # Third-party
 import matplotlib.pyplot as plt
@@ -7,12 +8,14 @@
 import pytorch_lightning as pl
 import torch
 import wandb
+import xarray as xr
 
 # Local
 from .. import metrics, vis
 from ..config import NeuralLAMConfig
 from ..datastore import BaseDatastore
 from ..loss_weighting import get_state_feature_weighting
+from ..weather_dataset import WeatherDataset
 
 
 class ARModel(pl.LightningModule):
@@ -147,6 +150,44 @@ def __init__(
         # For storing spatial loss maps during evaluation
         self.spatial_loss_maps = []
 
+    def _create_dataarray_from_tensor(
+        self,
+        tensor: torch.Tensor,
+        time: Union[int, List[int]],
+        split: str,
+        category: str,
+    ) -> xr.DataArray:
+        """
+        Create an `xr.DataArray` from a tensor, with the correct dimensions
+        and coordinates to match the datastore used by the model. This
+        function is in effect the inverse of what is returned by
+        `WeatherDataset.__getitem__`.
+
+        Parameters
+        ----------
+        tensor : torch.Tensor
+            The tensor to convert to a `xr.DataArray` with dimensions [time,
+            grid_index, feature]. The tensor will be copied to the CPU if it
+            is not already there.
+        time : Union[int, List[int]]
+            The time index or indices for the data, given as integers or a
+            list of integers representing epoch time in nanoseconds. The ints
+            will be copied to CPU memory if they are not already there.
+        split : str
+            The split of the data, either 'train', 'val', or 'test'
+        category : str
+            The category of the data, either 'state' or 'forcing'
+        """
+        # TODO: creating an instance of WeatherDataset here on every call is
+        # not how this should be done but whether WeatherDataset should be
+        # provided to ARModel or where to put plotting still needs discussion
+        weather_dataset = WeatherDataset(datastore=self._datastore, split=split)
+        time = np.array(time.cpu(), dtype="datetime64[ns]")
+        da = weather_dataset.create_dataarray_from_tensor(
+            tensor=tensor, time=time, category=category
+        )
+        return da
+
     def configure_optimizers(self):
         opt = torch.optim.AdamW(
             self.parameters(), lr=self.args.lr, betas=(0.9, 0.95)
@@ -406,10 +447,13 @@ def test_step(self, batch, batch_idx):
             )
 
             self.plot_examples(
-                batch, n_additional_examples, prediction=prediction
+                batch,
+                n_additional_examples,
+                prediction=prediction,
+                split="test",
             )
 
-    def plot_examples(self, batch, n_examples, prediction=None):
+    def plot_examples(self, batch, n_examples, split, prediction=None):
         """
         Plot the first n_examples forecasts from batch
 
@@ -422,18 +466,34 @@ def plot_examples(self, batch, n_examples, split, prediction=None):
             prediction, target, _, _ = self.common_step(batch)
 
         target = batch[1]
+        time = batch[3]
 
         # Rescale to original data scale
         prediction_rescaled = prediction * self.state_std + self.state_mean
         target_rescaled = target * self.state_std + self.state_mean
 
         # Iterate over the examples
-        for pred_slice, target_slice in zip(
-            prediction_rescaled[:n_examples], target_rescaled[:n_examples]
+        for pred_slice, target_slice, time_slice in zip(
+            prediction_rescaled[:n_examples],
+            target_rescaled[:n_examples],
+            time[:n_examples],
         ):
             # Each slice is (pred_steps, num_grid_nodes, d_f)
             self.plotted_examples += 1  # Increment already here
 
+            da_prediction = self._create_dataarray_from_tensor(
+                tensor=pred_slice,
+                time=time_slice,
+                split=split,
+                category="state",
+            ).unstack("grid_index")
+            da_target = self._create_dataarray_from_tensor(
+                tensor=target_slice,
+                time=time_slice,
+                split=split,
+                category="state",
+            ).unstack("grid_index")
+
             var_vmin = (
                 torch.minimum(
                     pred_slice.flatten(0, 1).min(dim=0)[0],
@@ -453,18 +513,20 @@ def plot_examples(self, batch, n_examples, split, prediction=None):
             var_vranges = list(zip(var_vmin, var_vmax))
 
             # Iterate over prediction horizon time steps
-            for t_i, (pred_t, target_t) in enumerate(
-                zip(pred_slice, target_slice), start=1
-            ):
+            for t_i, _ in enumerate(zip(pred_slice, target_slice), start=1):
                 # Create one figure per variable at this time step
                 var_figs = [
                     vis.plot_prediction(
-                        pred=pred_t[:, var_i],
-                        target=target_t[:, var_i],
                         datastore=self._datastore,
                         title=f"{var_name} ({var_unit}), "
                         f"t={t_i} ({self._datastore.step_length * t_i} h)",
                         vrange=var_vrange,
+                        da_prediction=da_prediction.isel(
+                            state_feature=var_i, time=t_i - 1
+                        ).squeeze(),
+                        da_target=da_target.isel(
+                            state_feature=var_i, time=t_i - 1
+                        ).squeeze(),
                     )
                     for var_i, (var_name, var_unit, var_vrange) in enumerate(
                         zip(
@@ -476,6 +538,7 @@ def plot_examples(self, batch, n_examples, split, prediction=None):
                 ]
 
                 example_i = self.plotted_examples
+
                 wandb.log(
                     {
                         f"{var_name}_example_{example_i}": wandb.Image(fig)
diff --git a/neural_lam/vis.py b/neural_lam/vis.py
index b9d18b39..d6b57f88 100644
--- a/neural_lam/vis.py
+++ b/neural_lam/vis.py
@@ -2,6 +2,7 @@
 import matplotlib
 import matplotlib.pyplot as plt
 import numpy as np
+import xarray as xr
 
 # Local
 from . import utils
@@ -65,9 +66,9 @@ def plot_error_map(errors, datastore: BaseRegularGridDatastore, title=None):
 
 @matplotlib.rc_context(utils.fractional_plot_bundle(1))
 def plot_prediction(
-    pred,
-    target,
     datastore: BaseRegularGridDatastore,
+    da_prediction: xr.DataArray = None,
+    da_target: xr.DataArray = None,
     title=None,
     vrange=None,
 ):
@@ -79,8 +80,8 @@ def plot_prediction(
     """
     # Get common scale for values
    if vrange is None:
-        vmin = min(vals.min().cpu().item() for vals in (pred, target))
-        vmax = max(vals.max().cpu().item() for vals in (pred, target))
+        vmin = min(da_prediction.min(), da_target.min())
+        vmax = max(da_prediction.max(), da_target.max())
     else:
         vmin, vmax = vrange
 
@@ -88,10 +89,8 @@ def plot_prediction(
 
     # Set up masking of border region
     da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
-    mask_reshaped = da_mask.values
-    pixel_alpha = (
-        mask_reshaped.clamp(0.7, 1).cpu().numpy()
-    )  # Faded border region
+    mask_values = np.invert(da_mask.values.astype(bool)).astype(float)
+    pixel_alpha = mask_values.clip(0.7, 1)  # Faded border region
 
     fig, axes = plt.subplots(
         1,
@@ -101,28 +100,23 @@ def plot_prediction(
     )
 
     # Plot pred and target
-    for ax, data in zip(axes, (target, pred)):
+    for ax, da in zip(axes, (da_target, da_prediction)):
         ax.coastlines()  # Add coastline outlines
-        data_grid = (
-            data.reshape(list(datastore.grid_shape_state.values.values()))
-            .cpu()
-            .numpy()
-        )
-        im = ax.imshow(
-            data_grid,
+        da.plot.imshow(
+            ax=ax,
             origin="lower",
+            x="x",
             extent=extent,
-            alpha=pixel_alpha,
+            alpha=pixel_alpha.T,
             vmin=vmin,
             vmax=vmax,
             cmap="plasma",
+            transform=datastore.coords_projection,
         )
 
     # Ticks and labels
     axes[0].set_title("Ground Truth", size=15)
     axes[1].set_title("Prediction", size=15)
-    cbar = fig.colorbar(im, aspect=30)
-    cbar.ax.tick_params(labelsize=10)
 
     if title:
         fig.suptitle(title, size=20)
@@ -150,9 +144,7 @@ def plot_spatial_error(
     # Set up masking of border region
     da_mask = datastore.unstack_grid_coords(datastore.boundary_mask)
     mask_reshaped = da_mask.values
-    pixel_alpha = (
-        mask_reshaped.clamp(0.7, 1).cpu().numpy()
-    )  # Faded border region
+    pixel_alpha = mask_reshaped.clip(0.7, 1)  # Faded border region
 
     fig, ax = plt.subplots(
         figsize=(5, 4.8),
@@ -161,8 +153,10 @@ def plot_spatial_error(
     ax.coastlines()  # Add coastline outlines
     error_grid = (
-        error.reshape(list(datastore.grid_shape_state.values.values()))
-        .cpu()
+        error.reshape(
+            [datastore.grid_shape_state.x, datastore.grid_shape_state.y]
+        )
+        .T.cpu()
         .numpy()
     )
diff --git a/neural_lam/weather_dataset.py b/neural_lam/weather_dataset.py
index 532e3c90..b5f85580 100644
--- a/neural_lam/weather_dataset.py
+++ b/neural_lam/weather_dataset.py
@@ -529,7 +529,8 @@ def create_dataarray_from_tensor(
         tensor : torch.Tensor
             The tensor to construct the DataArray from, this is assumed to have
             the same dimension ordering as returned by the __getitem__ method
-            (i.e. time, grid_index, {category}_feature).
+            (i.e. time, grid_index, {category}_feature). The tensor will be
+            copied to the CPU before constructing the DataArray.
         time : datetime.datetime or list[datetime.datetime]
             The time or times of the tensor.
         category : str
@@ -581,7 +582,7 @@ def _is_listlike(obj):
         coords["time"] = time
 
         da = xr.DataArray(
-            tensor.numpy(),
+            tensor.cpu().numpy(),
             dims=dims,
             coords=coords,
         )
diff --git a/pyproject.toml b/pyproject.toml
index f0bc0851..fdcb7f3e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -25,7 +25,7 @@ dependencies = [
     "torch>=2.3.0",
     "torch-geometric==2.3.1",
     "parse>=1.20.2",
-    "dataclass-wizard>=0.22.3",
+    "dataclass-wizard<0.31.0",
     "mllam-data-prep>=0.5.0",
 ]
 requires-python = ">=3.9"

From 1a128266fd144eada9d8e74724402c6e0e84267d Mon Sep 17 00:00:00 2001
From: Joel Oskarsson
Date: Fri, 6 Dec 2024 16:04:49 +0100
Subject: [PATCH 2/2] Change wandb env var to properly disable at start of testing (#94)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

## Describe your changes

Tests are not working properly, because wandb crashes when no user is logged in. This should not happen, since we disable wandb with an environment variable when running the tests, but that variable is the wrong one. This is an attempt to fix that, to make sure wandb is disabled when running the tests.

## Issue Link

None

## Type of change

- [x] 🐛 Bug fix (non-breaking change that fixes an issue)
- [ ] ✨ New feature (non-breaking change that adds functionality)
- [ ] 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] 📖 Documentation (Addition or improvements to documentation)

## Checklist before requesting a review

- [x] My branch is up-to-date with the target branch - if not, update your fork with the changes from the target branch (use `pull` with the `--rebase` option if possible).
- [ ] I have performed a self-review of my code
- [ ] For any new/modified functions/classes I have added docstrings that clearly describe its purpose, expected inputs and returned values
- [ ] I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
- [ ] I have updated the [README](README.MD) to cover introduced code changes
- [ ] I have added tests that prove my fix is effective or that my feature works
- [x] I have given the PR a name that clearly describes the change, written in imperative form ([context](https://www.gitkraken.com/learn/git/best-practices/git-commit-message#using-imperative-verb-form)).
- [x] I have requested a reviewer and an assignee (assignee is responsible for merging). This applies only if you have write access to the repo, otherwise feel free to tag a maintainer to add a reviewer and assignee.

## Checklist for reviewers

Each PR comes with its own improvements and flaws.
The reviewer should check the following:

- [ ] the code is readable
- [ ] the code is well tested
- [ ] the code is documented (including return types and parameters)
- [ ] the code is easy to maintain

## Author checklist after completed review

- [x] I have added a line to the CHANGELOG describing this change, in a section reflecting the type of change (add the section where missing):
  - *added*: when you have added new functionality
  - *changed*: when default behaviour of the code has been changed
  - *fixes*: when your contribution fixes a bug

## Checklist for assignee

- [x] PR is up to date with the base branch
- [x] the tests pass
- [x] author has added an entry to the changelog (and designated the change as *added*, *changed* or *fixed*)
- Once the PR is ready to be merged, squash commits and merge the PR.
---
 CHANGELOG.md      | 2 ++
 tests/conftest.py | 2 +-
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 01d4cac9..32961b16 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 
 ### Fixed
 
+- Fix wandb environment variable disabling wandb during tests. Now correctly uses WANDB_MODE=disabled. [\#94](https://github.com/mllam/neural-lam/pull/94) @joeloskarsson
+
 - Fix bugs introduced with the datastores functionality relating to visualisation plots [\#91](https://github.com/mllam/neural-lam/pull/91) @leifdenby
 
 ## [v0.2.0](https://github.com/joeloskarsson/neural-lam/releases/tag/v0.2.0)
diff --git a/tests/conftest.py b/tests/conftest.py
index 6f579621..5d799c73 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -17,7 +17,7 @@
 
 # Disable weights and biases to avoid unnecessary logging
 # and to avoid having to deal with authentication
-os.environ["WANDB_DISABLED"] = "true"
+os.environ["WANDB_MODE"] = "disabled"
 
 DATASTORE_EXAMPLES_ROOT_PATH = Path("tests/datastore_examples")
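
A note on the conftest change: wandb reads `WANDB_MODE` when a run is
initialised, and the value "disabled" turns every wandb call into a no-op,
so no login or network access is needed; per this patch, `WANDB_DISABLED`
did not have that effect for the test suite. A minimal sketch of the
behaviour the fix relies on (the project name below is made up):

    import os

    # Must be set before wandb initialises a run, hence before the test
    # session runs anything that calls wandb.init().
    os.environ["WANDB_MODE"] = "disabled"

    import wandb

    run = wandb.init(project="made-up-project")  # no authentication prompt
    run.log({"loss": 0.0})  # accepted and silently discarded
    run.finish()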