
Initial design of mllam-verification package #1

Open
mafdmi opened this issue Jan 7, 2025 · 16 comments

@mafdmi
Collaborator

mafdmi commented Jan 7, 2025

I've now started developing the inference vs. persistence plots. I thought I'd share my initial ideas for the design of the package, so that we're aligned before I write too much code. Please share any comments/thoughts/feedback! :)

  1. First of all, I thought this could be the start of a verification package for the mllam community, hence the name of the repo "mllam-verification".
  2. I've tried to follow the repo structure of mllam-data-prep as a starting point.
  3. I've defined an initial config file example with the following layout:
schema_version: v0.1.0

inputs:
  datasets:
    initial:
      path: /path/to/initial.zarr
    target:
      path: /path/to/target.zarr
    prediction:
      path: /path/to/prediction.zarr
  variables:
    - 2t
    - 10u
  coord_ranges:
    time:
      start: 1990-09-03T00:00
      end: 1990-09-09T00:00
      step: PT3H

methods:
  - global_persistence
  - gridpoint_persistence

output:
  path: /path/to/output/directory
  4. Since I like pydantic, I propose using it for validation.
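As a rough illustration of the pydantic idea, the config above could map onto models like the ones below. This is only a sketch: the class names are made up, `start`/`end`/`step` are kept as plain strings, and dataset names are free-form keys so that more than one prediction could be listed.

```python
from typing import Dict, List, Optional

from pydantic import BaseModel


class DatasetInput(BaseModel):
    path: str


class CoordRange(BaseModel):
    start: str
    end: str
    step: Optional[str] = None  # e.g. "PT3H"; optional in this sketch


class Inputs(BaseModel):
    # Free-form names ("initial", "target", "prediction", ...) -> dataset
    datasets: Dict[str, DatasetInput]
    # None means "all variables"
    variables: Optional[List[str]] = None
    coord_ranges: Dict[str, CoordRange] = {}


class Output(BaseModel):
    path: str


class VerificationConfig(BaseModel):
    schema_version: str
    inputs: Inputs
    methods: List[str]
    output: Output
```

With PyYAML, a config file could then be validated with `VerificationConfig(**yaml.safe_load(f))`; a missing or mistyped section raises `pydantic.ValidationError`.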

Some thoughts related to this structure:

  • I was about to call the "target" dataset "truth", but since I thought we would also like to use the package to e.g. verify the inference of one model against the inference of another, I went with "target". So the "target" dataset is what we verify against, and the "prediction" dataset is what we want to verify.
  • I added "coord_ranges" to make it possible to only verify a subset e.g. in time or space.
  • I propose that we use the same setup as we agreed upon for the statistics calculation in mllam-data-prep (see Add support for writing more composite statistics (e.g. grid-point based mean of time-step differences) mllam-data-prep#42). That is, we define which verification methods we want to calculate in the "methods" section. When parsing the config, we verify that those methods can be imported from within the package; if not, the script will fail.
  • I thought that we would not only be interested in saving plots to disk, but also datasets with the verification metrics. For now, I've just added a "path" parameter to the "output" section, so the plots and the verification datasets will both be saved to the same path. We can elaborate on this if needed, e.g. which variables we want to save.
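To make the `coord_ranges` and `methods` ideas a bit more concrete, here is a hedged sketch of how they could be applied. Everything in it is hypothetical: the helper names, the assumption that `step` only appears on datetime coordinates, and the restriction to steps of the form `PT<n>H`. `resolve_methods` takes the module to search as an argument, since the package's module layout isn't settled yet.

```python
import importlib
import re
from typing import Callable, Dict, List, Optional

import pandas as pd
import xarray as xr


def parse_hourly_step(step: str) -> pd.Timedelta:
    """Parse an ISO 8601 duration of the form "PT<n>H" (the only form handled here)."""
    match = re.fullmatch(r"PT(\d+)H", step)
    if match is None:
        raise ValueError(f"unsupported step {step!r}, expected e.g. 'PT3H'")
    return pd.Timedelta(hours=int(match.group(1)))


def select_coord_ranges(ds: xr.Dataset, coord_ranges: Dict[str, dict]) -> xr.Dataset:
    """Subset `ds` to the ranges given in the config's `coord_ranges` section."""
    for coord, spec in coord_ranges.items():
        step: Optional[str] = spec.get("step")
        if step is None:
            # Inclusive label range along any coordinate (time or space)
            ds = ds.sel({coord: slice(spec["start"], spec["end"])})
        else:
            # With a step, the coordinate is assumed to hold datetimes
            values = pd.date_range(spec["start"], spec["end"], freq=parse_hourly_step(step))
            ds = ds.sel({coord: values})
    return ds


def resolve_methods(names: List[str], module_name: str) -> Dict[str, Callable]:
    """Look up each method name in `module_name`, failing early if any is missing."""
    module = importlib.import_module(module_name)
    missing = [name for name in names if not callable(getattr(module, name, None))]
    if missing:
        raise ValueError(f"methods not found in {module_name}: {missing}")
    return {name: getattr(module, name) for name in names}
```

Calling `resolve_methods` while parsing the config gives the "fail early if a method can't be imported" behaviour described above.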
@mafdmi self-assigned this Jan 7, 2025
@leifdenby
Member

leifdenby commented Jan 7, 2025

This is cool! I hadn't considered using a config file as the interface to this, but what you suggest sounds good. I guess there will be some assumptions about the coordinates present in the input datasets (initial, target and prediction as you have them). Two quick thoughts on this:

  1. Might it be simpler to combine the "initial" and "target" into a "reference" dataset?
  2. If one was to compare a set of experiments (where there would be multiple "prediction" datasets) were you thinking we'd run the tool again, or would we adapt the config file?

It might be simpler to think of the function signatures for the two plots you are going to create. For the global persistence plot, something like the following might work:

from typing import Dict, Union
import matplotlib.pyplot
import xarray as xr

def create_global_persistence_timeline_plot(
    ds_reference: xr.Dataset,
    prediction: Union[xr.Dataset, Dict[str, xr.Dataset]],
) -> matplotlib.pyplot.Figure: ...

or you could instead rethink this as a global timeline plot (where including a line for persistence could be an option):

def create_global_error_timeline_plot(
    ds_reference: xr.Dataset,
    prediction: Union[xr.Dataset, Dict[str, xr.Dataset]],
    include_persistence: bool = True,
) -> matplotlib.pyplot.Figure: ...

The idea with allowing the argument to be a Dict[str,xr.Dataset] would be to make it possible to provide multiple prediction datasets and the keys could then be the names that are used as labels in the line plot.

A last idea would be to separate the plotting of these validation plots from the calculation of the thing to plot. That would mean returning an xr.Dataset that represents the lines of the persistence plot, and then you can use the xr.DataArray.plot() routines for the plotting. That could look something like:

def calculate_global_error(
    ds_reference: xr.Dataset,
    prediction: Union[xr.Dataset, Dict[str, xr.Dataset]],
    include_persistence: bool = True,
) -> xr.Dataset: ...

ds_prediction_1 = ...
ds_prediction_2 = ...
ds_reference = ...

ds_error = calculate_global_error(ds_reference, prediction=dict(model1=ds_prediction_1, model2=ds_prediction_2), include_persistence=True)
# ds_error would contain an extra dimension called say "model_name" with values `[model1, model2, persistence]`
ds_error.t2m.plot(hue="model_name", x="elapsed_time")

I am not sure about how to compose these functions. But in my experience it is useful to 1) return the values (typically data-arrays) from calculations in case you want to save these values to file, 2) think about what assumptions you will make about the inputs, and 3) think about how you'd like to use the functions you write (for example in a Jupyter notebook).

Just some ideas :)

@mafdmi
Collaborator Author

mafdmi commented Jan 8, 2025

Thanks for all the good ideas/thoughts!

> 1. Might it be simpler to combine the "initial" and "target" into a "reference" dataset?

I think that would make sense. I actually included the "initial" dataset mostly because you wrote it in your Confluence notes. I thought that the "initial" dataset would be identical (up to an interpolation) to the "target" dataset at time zero, in which case a comparison between the initial and the target at time zero wouldn't make much sense. Or do we actually perturb the initial conditions? There is probably something I've missed about the way neural-lam runs.

> 2. If one was to compare a set of experiments (where there would be multiple "prediction" datasets) were you thinking we'd run the tool again, or would we adapt the config file?

I think we should allow for multiple "prediction" datasets as you suggest below. Then I see two ways you can use the tool to compare a set of experiments:

a. Compare the multiple predictions to the same reference dataset
b. Define one of the predictions as the "reference" and compare the other predictions to it. This will give an anomaly-like plot with the reference as a constant line, and the other predictions as deviations relative to it.

I think in both cases, you should adjust the config to set what the reference dataset is and what the prediction datasets are, and then run the tool again.

> The idea with allowing the argument to be a Dict[str,xr.Dataset] would be to make it possible to provide multiple prediction datasets and the keys could then be the names that are used as labels in the line plot.

Makes sense. I will adjust the pydantic config validation to allow for multiple prediction datasets.

> A last idea would be to separate the plotting of these validation plots from the calculation of the thing to plot. That would mean returning an xr.Dataset that represents the lines of the persistence plot, and then you can use the xr.DataArray.plot() routines for the plotting.

Yes, I think that would be the best solution. Then we'll have some general plotting functions, which can plot datasets output by various calculation functions.

> I am not sure about how to compose these functions. But in my experience it is useful to 1) return the values (typically data-arrays) from calculations in case you want to save these values to file, 2) think about what assumptions you will make about the inputs, and 3) think about how you'd like to use the functions you write (for example in a Jupyter notebook).

Makes sense. I will talk with @elbdmi to get to know what I can expect the inference dataset to look like and define some assumptions based on that.

@leifdenby
Member

Sounds good! To make the code easier for you to write here it might be best to simply make assumptions about the "reference" and "prediction" datasets you will be using.

For the global time-error plot we could assume that:

  • ds_reference is xr.Dataset with coordinates [time, grid_index]
  • ds_prediction is xr.Dataset with coordinates [analysis_time, elapsed_forecast_duration, grid_index]
  • the error (both persistence and prediction vs reference error) should be computed for all variables
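Under these assumptions, the global error calculation could look roughly like the sketch below. It is not the package's implementation: it assumes every `analysis_time + elapsed_forecast_duration` exists exactly in `ds_reference.time`, uses RMSE averaged over `analysis_time` and `grid_index` as the error measure, and takes "persistence" to mean holding the reference state at `analysis_time` fixed for all lead times. The error curves are stacked along a `model_name` dimension, as in the earlier example.

```python
import pandas as pd
import xarray as xr


def calculate_global_error(
    ds_reference: xr.Dataset,
    ds_prediction: xr.Dataset,
    include_persistence: bool = True,
    model_name: str = "prediction",
) -> xr.Dataset:
    """Domain-averaged RMSE per elapsed_forecast_duration (sketch).

    Assumes ds_reference has dims [time, grid_index] and ds_prediction has
    dims [analysis_time, elapsed_forecast_duration, grid_index].
    """
    # Valid time of every forecast step, dims [analysis_time, elapsed_forecast_duration]
    valid_time = ds_prediction["analysis_time"] + ds_prediction["elapsed_forecast_duration"]
    # Reference state at each valid time; vectorised selection replaces
    # `time` with the two forecast dimensions
    ds_truth = ds_reference.sel(time=valid_time).drop_vars("time")

    def rmse(ds_forecast: xr.Dataset) -> xr.Dataset:
        squared_error = (ds_forecast - ds_truth) ** 2
        return squared_error.mean(dim=["analysis_time", "grid_index"]) ** 0.5

    errors = [rmse(ds_prediction)]
    labels = [model_name]
    if include_persistence:
        # Persistence: the reference state at analysis_time, broadcast over
        # every elapsed_forecast_duration when compared with ds_truth
        ds_persistence = ds_reference.sel(time=ds_prediction["analysis_time"]).drop_vars("time")
        errors.append(rmse(ds_persistence))
        labels.append("persistence")
    # Stack the error curves along a new "model_name" dimension
    return xr.concat(errors, dim=pd.Index(labels, name="model_name"))
```

The result could then be plotted with something like `ds_error.t2m.plot(hue="model_name", x="elapsed_forecast_duration")`.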

The reason why I suggest grid_index rather than [x, y] is that we may in future have data that is not given on a regular 2D grid, so we might as well make this as general as possible for now.

For the spatial error plot for a given elapsed time we could assume that:

  • ds_reference is xr.Dataset with coordinates [time, x, y]
  • ds_prediction is xr.Dataset with coordinates [analysis_time, elapsed_forecast_duration, x, y]
  • the error (both persistence and prediction vs reference error) should be computed for all variables

For the spatial plot we need to have the [x, y] coordinates of course :)
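For the spatial case, a per-gridpoint error could be computed in much the same way, but averaging only over `analysis_time` so that `x`, `y` and `elapsed_forecast_duration` are kept. This is a sketch under the same assumptions (exact matches between valid times and `ds_reference.time`, RMSE as the metric); the function name echoes `calculate_error_per_gridpoint` mentioned later in this thread, but the body here is a guess.

```python
import xarray as xr


def calculate_error_per_gridpoint(ds_reference: xr.Dataset, ds_prediction: xr.Dataset) -> xr.Dataset:
    """RMSE over analysis_time at every (x, y) point and elapsed_forecast_duration (sketch).

    Assumes ds_reference has dims [time, x, y] and ds_prediction has
    dims [analysis_time, elapsed_forecast_duration, x, y].
    """
    # Valid time of every forecast step, dims [analysis_time, elapsed_forecast_duration]
    valid_time = ds_prediction["analysis_time"] + ds_prediction["elapsed_forecast_duration"]
    ds_truth = ds_reference.sel(time=valid_time).drop_vars("time")
    squared_error = (ds_prediction - ds_truth) ** 2
    # Average over forecasts only, so the spatial dims remain for the map
    return squared_error.mean(dim="analysis_time") ** 0.5
```

A map for one lead time could then be drawn with e.g. `ds_error.t2m.sel(elapsed_forecast_duration=pd.Timedelta(hours=6)).plot(x="x", y="y")`.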

I think you've got the right idea in working out what the inputs should contain. I would define that in terms of the functions first (since then we can build on that in notebooks and in separate functions that actually load from disk), and then, as @elbdmi keeps working, we can constrain what is stored in the zarr datasets too.

It might be simpler to only allow calculations for a single prediction dataset for now. Then your code could be called to write a number of netCDF files which contain the output of the error calculations and then plotting could be done by combining multiple of these netCDF files. So this means dropping the idea with the Dict above :) I am probably over-complicating things with this idea.

Anyway, I think the key is to get the functions that compute the error data as xr.Datasets done and write some tests and make some plots with that. Looking forward to seeing what you and @elbdmi work out!

@elbdmi

elbdmi commented Jan 8, 2025

For the ds_prediction I have a state variable with coordinates [analysis_time, elapsed_forecast_duration, grid_index, state_feature] but I could remove the state_feature

@mafdmi
Collaborator Author

mafdmi commented Jan 8, 2025

Sounds good. I will start working on it this afternoon.

> • the error (both persistence and prediction vs reference error) should be computed for all variables

To be sure I understand you correctly, do you mean that we always should calculate the error for all variables, or is it fine to keep the

inputs:
  ...
  variables:
    - 2t
    - 10u

part? I imagine that if you don't specify the "variables" section, we calculate the error for all variables, but if you specify certain variables, we only calculate the error for those.
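A minimal sketch of that behaviour (the helper name `select_variables` is hypothetical): an absent `variables` section means all variables, otherwise only the listed ones are kept, and unknown names fail early.

```python
from typing import List, Optional

import xarray as xr


def select_variables(ds: xr.Dataset, variables: Optional[List[str]] = None) -> xr.Dataset:
    """Return only the requested variables, or all of them if none are given."""
    if variables is None:
        return ds
    missing = [name for name in variables if name not in ds.data_vars]
    if missing:
        raise KeyError(f"variables not found in dataset: {missing}")
    # Indexing a Dataset with a list of names returns a new Dataset
    return ds[variables]
```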

@mafdmi
Collaborator Author

mafdmi commented Jan 8, 2025

> For the ds_prediction I have a state variable with coordinates [analysis_time, elapsed_forecast_duration, grid_index, state_feature] but I could remove the state_feature

What will the state_feature contain?

@elbdmi

elbdmi commented Jan 8, 2025

State is the main variable containing predictions with dimensions [analysis_time, elapsed_forecast_duration, grid_index, state_feature]. The state_feature contains the specific variables or physical quantities that the model predicts for example 2t or 10u.

@mafdmi
Collaborator Author

mafdmi commented Jan 8, 2025

> State is the main variable containing predictions with dimensions [analysis_time, elapsed_forecast_duration, grid_index, state_feature]. The state_feature contains the specific variables or physical quantities that the model predicts for example 2t or 10u.

Okay, so if we want mllam-verification to be able to calculate the error for specific physical quantities only, I guess we should keep the state_feature :)

@leifdenby
Member

> For the ds_prediction I have a state variable with coordinates [analysis_time, elapsed_forecast_duration, grid_index, state_feature] but I could remove the state_feature

Yes, sort of :) To do this, the process would be to go from state_feature and split each feature out into a separate variable again. This would include getting the units and long_name attributes from the variables that contain those (state_feature_units and state_feature_long_name). In effect, this reverses the transformations that mllam-data-prep applies in going from source data to transformed data ready for creating training datasets. I started work on this on a branch a while ago https://github.com/leifdenby/mllam-data-prep/tree/feat/inverse-ops @elbdmi, maybe we can sit and look at this together next week? So what I'm saying is that I can help with converting from the shape you are planning to output :) Does that sound ok?
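A rough sketch of the splitting step described above, assuming `state` has a `state_feature` dimension and that `state_feature_units` and `state_feature_long_name` are variables or coordinates indexed by `state_feature`. This is not the code on the `feat/inverse-ops` branch, just an illustration, and it does not attempt any of the other inverse transformations.

```python
import xarray as xr


def split_state_features(ds: xr.Dataset, var_name: str = "state") -> xr.Dataset:
    """Turn a stacked [..., state_feature] variable back into one variable per feature."""
    stacked = ds[var_name]
    # Coordinates along state_feature that make no sense once a feature is selected
    feature_coords = ["state_feature", "state_feature_units", "state_feature_long_name"]
    data_vars = {}
    for feature in stacked["state_feature"].values:
        da = stacked.sel(state_feature=feature).drop_vars(feature_coords, errors="ignore")
        # CF-style attributes taken from the per-feature metadata
        da.attrs = {
            "units": str(ds["state_feature_units"].sel(state_feature=feature).item()),
            "long_name": str(ds["state_feature_long_name"].sel(state_feature=feature).item()),
        }
        data_vars[str(feature)] = da
    return xr.Dataset(data_vars)
```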

@leifdenby
Member

> I imagine that if you don't specify the "variables" section, we calculate the error for all variables, but if you specify certain variables, we only calculate the error for those variables.

I was thinking that it might be good to make the functions that compute the error values do the calculation on all variables in the dataset, and then apply your idea of picking out the ones you'd like separately when plotting. If everything is loaded with dask (which it will be when using xr.open_dataset("path.zarr", chunks={}) or xr.open_zarr("path.zarr")), the calculation of those error values will be lazy anyway, and it won't take place until you render the plot or write the values to disk.

@mafdmi
Collaborator Author

mafdmi commented Jan 8, 2025

Okay, let's do that. I just thought that if we at some point get a lot of variables out of the inference, we might not necessarily be interested in calculating and storing the error (or other verification metrics) for all of them. But that's maybe a problem for the future (if a problem at all).

@mafdmi
Collaborator Author

mafdmi commented Jan 8, 2025

> For the ds_prediction I have a state variable with coordinates [analysis_time, elapsed_forecast_duration, grid_index, state_feature] but I could remove the state_feature
>
> Yes, sort of :) To do this, the process would be to go from state_feature and split each feature out into a separate variable again. This would include getting the units and long_name attributes from the variables that contain those (state_feature_units and state_feature_long_name). In effect, this reverses the transformations that mllam-data-prep applies in going from source data to transformed data ready for creating training datasets. I started work on this on a branch a while ago https://github.com/leifdenby/mllam-data-prep/tree/feat/inverse-ops @elbdmi, maybe we can sit and look at this together next week? So what I'm saying is that I can help with converting from the shape you are planning to output :) Does that sound ok?

Would it then make sense to add a state_feature coordinate to the error datasets too? Like

  • ds_reference with coordinates [time, grid_index, state_feature]
  • ds_prediction with coordinates [analysis_time, elapsed_forecast_duration, grid_index, state_feature]

@leifdenby
Member

> Would it then make sense to add a state_feature coordinate to the error datasets too? Like
>
>   • ds_reference with coordinates [time, grid_index, state_feature]
>   • ds_prediction with coordinates [analysis_time, elapsed_forecast_duration, grid_index, state_feature]

It could, yes, but I was thinking we would try to make this codebase agnostic to whether one is using a machine learning model or not. Said another way: if we make this tool assume that the inputs are (as close as possible to) CF-compliant in their contents (i.e. different physical fields in separate variables, with units, long_name and one day standard_name attributes), then it will be much easier to use with other tools (for example, xr.DataArray.plot() automatically puts units and long_name on the plot). That way this tool becomes a general-purpose tool that can be used for model verification (we could have used HARP of course, but that is a codebase in R and so difficult to interface with).

@elbdmi

elbdmi commented Jan 8, 2025

> For the ds_prediction I have a state variable with coordinates [analysis_time, elapsed_forecast_duration, grid_index, state_feature] but I could remove the state_feature
>
> Yes, sort of :) To do this, the process would be to go from state_feature and split each feature out into a separate variable again. This would include getting the units and long_name attributes from the variables that contain those (state_feature_units and state_feature_long_name). In effect, this reverses the transformations that mllam-data-prep applies in going from source data to transformed data ready for creating training datasets. I started work on this on a branch a while ago https://github.com/leifdenby/mllam-data-prep/tree/feat/inverse-ops @elbdmi, maybe we can sit and look at this together next week? So what I'm saying is that I can help with converting from the shape you are planning to output :) Does that sound ok?

@leifdenby, that sounds great! I’d love to sit down and go through this together next week. Reversing the transformations to split the state_feature into separate variables with attributes like units and long_name makes complete sense, and your branch seems like a good starting point for this work.

@mafdmi
Collaborator Author

mafdmi commented Jan 9, 2025

Concerning the spatial error plot would you prefer that we just save one plot per elapsed time, an animation/gif, a Hovmöller diagram or something else?

@mafdmi
Collaborator Author

mafdmi commented Jan 10, 2025

I've now worked out tests that use synthetic data to call calculate_global_error and calculate_error_per_gridpoint, as well as plot functions, to produce the following plots:
[Plot: error_map]
[Plot: error_timeline]
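For reference, synthetic test data in the agreed shapes can be generated along these lines. This is a hypothetical helper, not the one used in the repo's tests: it builds a reference with dims [time, grid_index] and a prediction with dims [analysis_time, elapsed_forecast_duration, grid_index] that equals the reference at each valid time plus Gaussian noise.

```python
import numpy as np
import pandas as pd
import xarray as xr


def make_synthetic_datasets(n_times=12, n_grid=5, n_analysis=3, lead_hours=(0, 3, 6), noise=0.5, seed=0):
    """Synthetic (ds_reference, ds_prediction) pair in the agreed shapes (sketch)."""
    rng = np.random.default_rng(seed)
    attrs = {"units": "K", "long_name": "2 metre temperature"}
    times = pd.date_range("1990-09-03", periods=n_times, freq=pd.Timedelta(hours=3))
    ds_reference = xr.Dataset(
        {"t2m": (("time", "grid_index"), 280.0 + rng.normal(size=(n_times, n_grid)), attrs)},
        coords={"time": times, "grid_index": np.arange(n_grid)},
    )
    analysis_time = times[:n_analysis]
    elapsed = pd.to_timedelta(list(lead_hours), unit="h")
    # Valid times, dims [analysis_time, elapsed_forecast_duration]; all must lie in `times`
    valid_time = xr.DataArray(
        analysis_time, dims="analysis_time", coords={"analysis_time": analysis_time}
    ) + xr.DataArray(elapsed, dims="elapsed_forecast_duration", coords={"elapsed_forecast_duration": elapsed})
    truth = ds_reference["t2m"].sel(time=valid_time).drop_vars("time")
    # Prediction = reference at the valid time + Gaussian noise
    predicted = truth + rng.normal(scale=noise, size=truth.shape)
    ds_prediction = xr.Dataset({"t2m": predicted.assign_attrs(attrs)})
    return ds_reference, ds_prediction
```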
