log datastore config
Leif Denby committed Aug 14, 2024
1 parent 57bbb81 commit a955cee
Showing 6 changed files with 75 additions and 27 deletions.
14 changes: 14 additions & 0 deletions neural_lam/datastore/base.py
@@ -1,5 +1,6 @@
# Standard library
import abc
import collections
import dataclasses
from pathlib import Path
from typing import List, Union
@@ -59,6 +60,19 @@ def root_path(self) -> Path:
"""
pass

@property
@abc.abstractmethod
def config(self) -> collections.abc.Mapping:
"""The configuration of the datastore.
Returns
-------
collections.abc.Mapping
The configuration of the datastore; any dict-like object may be returned.
"""
pass

@property
@abc.abstractmethod
def step_length(self) -> int:
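For illustration only (not part of this commit), a minimal sketch of how a concrete datastore could satisfy the new abstract `config` property by returning a plain dict loaded from YAML; the `ExampleYamlDatastore` class, its constructor signature, and the YAML loading are assumptions.

# Illustrative sketch: a concrete datastore satisfying the new abstract
# `config` property with a plain dict (a Mapping) loaded from a YAML file.
import collections.abc
from pathlib import Path

import yaml


class ExampleYamlDatastore:
    """Hypothetical datastore whose config is a plain dict."""

    def __init__(self, config_path: str):
        self._config_path = Path(config_path)
        with open(self._config_path) as f:
            self._config = yaml.safe_load(f)

    @property
    def config(self) -> collections.abc.Mapping:
        """The configuration of the datastore."""
        return self._config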
12 changes: 12 additions & 0 deletions neural_lam/datastore/mllam.py
@@ -86,6 +86,18 @@ def root_path(self) -> Path:
"""
return self._root_path

@property
def config(self) -> mdp.Config:
"""The configuration of the dataset.
Returns
-------
mdp.Config
The configuration of the dataset.
"""
return self._config

@property
def step_length(self) -> int:
"""The length of the time steps in hours.
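Because concrete stores may back `config` with either a dict-like mapping or a config object such as `mdp.Config`, a caller that wants a single value can branch on the type; a minimal sketch, where the `config_get` helper and the example key are assumptions, not part of this commit:

# Sketch: read a value from a datastore config regardless of whether the
# concrete store returns a dict-like mapping or an object with attributes.
import collections.abc


def config_get(config, key, default=None):
    if isinstance(config, collections.abc.Mapping):
        return config.get(key, default)
    return getattr(config, key, default)


# e.g. config_get(datastore.config, "dataset", default={})  # illustrative key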
12 changes: 12 additions & 0 deletions neural_lam/datastore/multizarr/store.py
Expand Up @@ -55,6 +55,18 @@ def root_path(self):
"""
return self._root_path

@property
def config(self) -> dict:
"""Return the configuration dictionary.
Returns
-------
dict
The configuration dictionary.
"""
return self._config

def _normalize_path(self, path) -> str:
"""
Normalize the path of source-dataset defined in the configuration file.
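The body of `_normalize_path` is not shown in this hunk; a plausible sketch (an assumption, not the actual implementation) resolves relative source-dataset paths against the directory holding the configuration file:

# Assumed sketch of path normalisation: paths that are not absolute are
# resolved against the parent directory of the datastore config file.
from pathlib import Path


class _PathNormalizerSketch:
    def __init__(self, root_path: Path):
        # parent directory of the datastore configuration file
        self._root_path = root_path

    def _normalize_path(self, path) -> str:
        path = Path(path)
        if not path.is_absolute():
            path = self._root_path / path
        return str(path)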
46 changes: 20 additions & 26 deletions neural_lam/datastore/npyfiles/store.py
@@ -138,17 +138,9 @@ def __init__(
self,
config_path,
):
"""Create a new
NpyFilesDatastore
using the
configuration file at
the given path. The
config file should be
a YAML file and will
be loaded into an
instance of the
`NpyDatastoreConfig`
dataclass.
"""Create a new NpyFilesDatastore using the configuration file at the given
path. The config file should be a YAML file and will be loaded into an instance
of the `NpyDatastoreConfig` dataclass.
Internally, the datastore uses dask.delayed to load the data from the
numpy files, so that the data isn't actually loaded until it's needed.
@@ -166,7 +158,7 @@ def __init__(

self._config_path = Path(config_path)
self._root_path = self._config_path.parent
- self.config = NpyDatastoreConfig.from_yaml_file(self._config_path)
+ self._config = NpyDatastoreConfig.from_yaml_file(self._config_path)

@property
def root_path(self) -> Path:
@@ -181,21 +173,23 @@ def root_path(self) -> Path:
"""
return self._root_path

@property
def config(self) -> NpyDatastoreConfig:
"""The configuration for the datastore.
Returns
-------
NpyDatastoreConfig
The configuration for the datastore.
"""
return self._config

def get_dataarray(self, category: str, split: str) -> DataArray:
"""Get the data array
for the given category
and split of data. If
the category is
'state', the data
array will be a
concatenation of the
data arrays for all
ensemble members. The
data will be loaded as
a dask array, so that
the data isn't
actually loaded until
it's needed.
"""Get the data array for the given category and split of data. If the category
is 'state', the data array will be a concatenation of the data arrays for all
ensemble members. The data will be loaded as a dask array, so that the data
isn't actually loaded until it's needed.
Parameters
----------
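The docstrings above describe loading the numpy files lazily via `dask.delayed` and concatenating ensemble members for the 'state' category; a simplified sketch of that pattern follows, with the file handling, dimension names and shapes being assumptions for illustration.

# Simplified sketch of the lazy-loading pattern the docstring describes:
# each .npy file is wrapped in dask.delayed so nothing is read from disk
# until the array is computed, and ensemble members are stacked along a
# new dimension. Assumes shape = (num_timesteps, num_grid_points, num_features).
import dask
import dask.array as da
import numpy as np
import xarray as xr


def _lazy_member(path, shape, dtype=np.float32):
    """Wrap a single .npy file in a delayed load; no data is read yet."""
    delayed_load = dask.delayed(np.load)(path)
    return da.from_delayed(delayed_load, shape=shape, dtype=dtype)


def load_state_ensemble(member_paths, shape):
    """Stack lazily-loaded ensemble members along 'ensemble_member'."""
    members = [_lazy_member(path, shape) for path in member_paths]
    stacked = da.stack(members, axis=0)
    return xr.DataArray(
        stacked, dims=("ensemble_member", "time", "grid_index", "feature")
    )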
4 changes: 3 additions & 1 deletion neural_lam/train_model.py
@@ -280,7 +280,9 @@ def main(input_args=None):
save_last=True,
)
logger = pl.loggers.WandbLogger(
- project=args.wandb_project, name=run_name, config=args
+ project=args.wandb_project,
+ name=run_name,
+ config=dict(training=vars(args), datastore=datastore._config),
)
trainer = pl.Trainer(
max_epochs=args.epochs,
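The logger change passes `datastore._config` into the W&B `config` dict; when a store's config is a dataclass rather than a mapping, converting it to a plain dict first keeps the logged config serialisable. A minimal sketch of such a conversion (the `loggable_config` helper is an assumption, not part of this commit):

# Sketch: convert a datastore config (Mapping or dataclass instance) into a
# plain dict suitable for logging, falling back to the object unchanged.
import collections.abc
import dataclasses


def loggable_config(config):
    if dataclasses.is_dataclass(config) and not isinstance(config, type):
        return dataclasses.asdict(config)
    if isinstance(config, collections.abc.Mapping):
        return dict(config)
    return config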
14 changes: 14 additions & 0 deletions tests/test_datastores.py
@@ -16,6 +16,7 @@
`xr.DataArray`) for the given category and test/train/val-split.
- [x] `boundary_mask` (property): Return the boundary mask for the dataset,
with spatial dimensions stacked.
- [x] `config` (property): Return the configuration of the datastore.
In addition BaseCartesianDatastore must have the following methods and attributes:
- [x] `get_xy_extent` (method): Return the extent of the x, y coordinates for a
@@ -27,6 +28,8 @@
"""

# Standard library
import collections
import dataclasses
from pathlib import Path

# Third-party
@@ -47,6 +50,17 @@ def test_root_path(datastore_name):
assert isinstance(datastore.root_path, Path)


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_config(datastore_name):
"""Check that the `datastore.config` property is implemented."""
datastore = init_datastore(datastore_name)
# check the config is a mapping or a dataclass
config = datastore.config
assert isinstance(config, collections.abc.Mapping) or dataclasses.is_dataclass(
config
)


@pytest.mark.parametrize("datastore_name", DATASTORES.keys())
def test_step_length(datastore_name):
"""Check that the `datastore.step_length` property is implemented."""
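The tests import a `DATASTORES` registry and an `init_datastore` helper from the test setup; a hypothetical sketch of how they might look, with the class names, import paths and example config paths all being assumptions for illustration:

# Hypothetical sketch of the DATASTORES registry and init_datastore helper
# the tests above rely on; names and example config paths are assumptions.
from neural_lam.datastore.mllam import MLLAMDatastore
from neural_lam.datastore.multizarr.store import MultiZarrDatastore
from neural_lam.datastore.npyfiles.store import NpyFilesDatastore

DATASTORES = {
    "mllam": MLLAMDatastore,
    "multizarr": MultiZarrDatastore,
    "npyfiles": NpyFilesDatastore,
}

# assumed locations of small example configs used only by the test suite
EXAMPLE_CONFIG_PATHS = {
    "mllam": "tests/datastore_configs/mllam/example.yaml",
    "multizarr": "tests/datastore_configs/multizarr/data_config.yaml",
    "npyfiles": "tests/datastore_configs/npyfiles/data_config.yaml",
}


def init_datastore(datastore_name):
    """Instantiate the named datastore from its example configuration file."""
    DatastoreClass = DATASTORES[datastore_name]
    return DatastoreClass(config_path=EXAMPLE_CONFIG_PATHS[datastore_name])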
