From bf8172a4d75414a236ae6b1a61287588bda80e86 Mon Sep 17 00:00:00 2001 From: Leif Denby Date: Thu, 12 Sep 2024 11:32:46 +0200 Subject: [PATCH] MLLAMDatastore -> MDPDatastore --- neural_lam/datastore/__init__.py | 4 ++-- neural_lam/datastore/{mllam.py => mdp.py} | 12 +++++++++--- neural_lam/train_model.py | 4 ++-- 3 files changed, 13 insertions(+), 7 deletions(-) rename neural_lam/datastore/{mllam.py => mdp.py} (96%) diff --git a/neural_lam/datastore/__init__.py b/neural_lam/datastore/__init__.py index 479d31a9..901841db 100644 --- a/neural_lam/datastore/__init__.py +++ b/neural_lam/datastore/__init__.py @@ -1,9 +1,9 @@ # Local -from .mllam import MLLAMDatastore # noqa +from .mdp import MDPDatastore # noqa from .npyfiles import NpyFilesDatastore # noqa DATASTORES = dict( - mllam=MLLAMDatastore, + mdp=MDPDatastore, npyfiles=NpyFilesDatastore, ) diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mdp.py similarity index 96% rename from neural_lam/datastore/mllam.py rename to neural_lam/datastore/mdp.py index b0867d02..f40a74a5 100644 --- a/neural_lam/datastore/mllam.py +++ b/neural_lam/datastore/mdp.py @@ -15,12 +15,18 @@ from .base import BaseCartesianDatastore, CartesianGridShape -class MLLAMDatastore(BaseCartesianDatastore): - """Datastore class for the MLLAM dataset.""" +class MDPDatastore(BaseCartesianDatastore): + """ + Datastore class for datasets made with the mllam_data_prep library + (https://github.com/mllam/mllam-data-prep). This class wraps the + `mllam_data_prep` library to do the necessary transforms to create the + different categories (state/forcing/static) of data, with the actual + transform to do being specified in the configuration file. + """ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True): """ - Construct a new MLLAMDatastore from the configuration file at + Construct a new MDPDatastore from the configuration file at `config_path`. A boundary mask is created with `n_boundary_points` boundary points. If `reuse_existing` is True, the dataset is loaded from a zarr file if it exists (unless the config has been modified diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 890f80fe..0913319e 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -11,7 +11,7 @@ # Local from . import utils -from .datastore import init_datastore +from .datastore import DATASTORES, init_datastore from .models import GraphLAM, HiLAM, HiLAMParallel from .weather_dataset import WeatherDataModule @@ -30,7 +30,7 @@ def main(input_args=None): parser.add_argument( "datastore_kind", type=str, - choices=["npyfiles", "mllam"], + choices=DATASTORES.keys(), help="Kind of datastore to use", ) parser.add_argument(