diff --git a/neural_lam/datastore/base.py b/neural_lam/datastore/base.py index 1b662fa4..480476fe 100644 --- a/neural_lam/datastore/base.py +++ b/neural_lam/datastore/base.py @@ -1,5 +1,6 @@ # Standard library import abc +import collections import dataclasses from pathlib import Path from typing import List, Union @@ -59,6 +60,19 @@ def root_path(self) -> Path: """ pass + @property + @abc.abstractmethod + def config(self) -> collections.abc.Mapping: + """The configuration of the datastore. + + Returns + ------- + collections.abc.Mapping + The configuration of the datastore, any dict like object can be returned. + + """ + pass + @property @abc.abstractmethod def step_length(self) -> int: diff --git a/neural_lam/datastore/mllam.py b/neural_lam/datastore/mllam.py index 0d011e5e..d22f041a 100644 --- a/neural_lam/datastore/mllam.py +++ b/neural_lam/datastore/mllam.py @@ -86,6 +86,18 @@ def root_path(self) -> Path: """ return self._root_path + @property + def config(self) -> mdp.Config: + """The configuration of the dataset. + + Returns + ------- + mdp.Config + The configuration of the dataset. + + """ + return self._config + @property def step_length(self) -> int: """The length of the time steps in hours. diff --git a/neural_lam/datastore/multizarr/store.py b/neural_lam/datastore/multizarr/store.py index 18af8457..23b33fe2 100644 --- a/neural_lam/datastore/multizarr/store.py +++ b/neural_lam/datastore/multizarr/store.py @@ -55,6 +55,18 @@ def root_path(self): """ return self._root_path + @property + def config(self) -> dict: + """Return the configuration dictionary. + + Returns + ------- + dict + The configuration dictionary. + + """ + return self._config + def _normalize_path(self, path) -> str: """ Normalize the path of source-dataset defined in the configuration file. 
diff --git a/neural_lam/datastore/npyfiles/store.py b/neural_lam/datastore/npyfiles/store.py index 630a8dd0..cff20043 100644 --- a/neural_lam/datastore/npyfiles/store.py +++ b/neural_lam/datastore/npyfiles/store.py @@ -138,17 +138,9 @@ def __init__( self, config_path, ): - """Create a new - NpyFilesDatastore - using the - configuration file at - the given path. The - config file should be - a YAML file and will - be loaded into an - instance of the - `NpyDatastoreConfig` - dataclass. + """Create a new NpyFilesDatastore using the configuration file at the given + path. The config file should be a YAML file and will be loaded into an instance + of the `NpyDatastoreConfig` dataclass. Internally, the datastore uses dask.delayed to load the data from the numpy files, so that the data isn't actually loaded until it's needed. @@ -166,7 +158,7 @@ def __init__( self._config_path = Path(config_path) self._root_path = self._config_path.parent - self.config = NpyDatastoreConfig.from_yaml_file(self._config_path) + self._config = NpyDatastoreConfig.from_yaml_file(self._config_path) @property def root_path(self) -> Path: @@ -181,21 +173,23 @@ def root_path(self) -> Path: """ return self._root_path + @property + def config(self) -> NpyDatastoreConfig: + """The configuration for the datastore. + + Returns + ------- + NpyDatastoreConfig + The configuration for the datastore. + + """ + return self._config + def get_dataarray(self, category: str, split: str) -> DataArray: - """Get the data array - for the given category - and split of data. If - the category is - 'state', the data - array will be a - concatenation of the - data arrays for all - ensemble members. The - data will be loaded as - a dask array, so that - the data isn't - actually loaded until - it's needed. + """Get the data array for the given category and split of data. If the category + is 'state', the data array will be a concatenation of the data arrays for all + ensemble members. 
The data will be loaded as a dask array, so that the data + isn't actually loaded until it's needed. Parameters ---------- diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index 4f011b76..e819c403 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -280,7 +280,9 @@ def main(input_args=None): save_last=True, ) logger = pl.loggers.WandbLogger( - project=args.wandb_project, name=run_name, config=args + project=args.wandb_project, + name=run_name, + config=dict(training=vars(args), datastore=datastore.config), ) trainer = pl.Trainer( max_epochs=args.epochs, diff --git a/tests/test_datastores.py b/tests/test_datastores.py index 319c5a7c..512bc5a0 100644 --- a/tests/test_datastores.py +++ b/tests/test_datastores.py @@ -16,6 +16,7 @@ `xr.DataArray`) for the given category and test/train/val-split. - [x] `boundary_mask` (property): Return the boundary mask for the dataset, with spatial dimensions stacked. +- [x] `config` (property): Return the configuration of the datastore. In addition BaseCartesianDatastore must have the following methods and attributes: - [x] `get_xy_extent` (method): Return the extent of the x, y coordinates for a @@ -27,6 +28,8 @@ """ # Standard library +import collections +import dataclasses from pathlib import Path # Third-party @@ -47,6 +50,17 @@ def test_root_path(datastore_name): assert isinstance(datastore.root_path, Path) +@pytest.mark.parametrize("datastore_name", DATASTORES.keys()) +def test_config(datastore_name): +    """Check that the `datastore.config` property is implemented.""" +    datastore = init_datastore(datastore_name) +    # check the config is a mapping or a dataclass +    config = datastore.config +    assert isinstance(config, collections.abc.Mapping) or dataclasses.is_dataclass( +        config +    ) + + @pytest.mark.parametrize("datastore_name", DATASTORES.keys()) def test_step_length(datastore_name): """Check that the `datastore.step_length` property is implemented."""