Skip to content

Commit

Permalink
cleanup docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
Leif Denby committed Aug 15, 2024
1 parent a955cee commit 0a79c74
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 154 deletions.
14 changes: 3 additions & 11 deletions neural_lam/datastore/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,9 @@


class BaseDatastore(abc.ABC):
"""Base class for weather
data used in the neural-
lam package. A datastore
defines the interface for
accessing weather data by
providing methods to
access the data in a
processed format that can
be used for training and
evaluation of neural
networks.
"""Base class for weather data used in the neural- lam package. A datastore defines
the interface for accessing weather data by providing methods to access the data in
a processed format that can be used for training and evaluation of neural networks.
NOTE: All methods return either primitive types, `numpy.ndarray`,
`xarray.DataArray` or `xarray.Dataset` objects, not `pytorch.Tensor`
Expand Down
95 changes: 24 additions & 71 deletions neural_lam/datastore/mllam.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,24 +19,11 @@ class MLLAMDatastore(BaseCartesianDatastore):
"""Datastore class for the MLLAM dataset."""

def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
"""Construct a new
MLLAMDatastore from
the configuration file
at `config_path`. A
boundary mask is
created with
`n_boundary_points`
boundary points. If
`reuse_existing` is
True, the dataset is
loaded from a zarr
file if it exists
(unless the config has
been modified since
the zarr was created),
otherwise it is
created from the
configuration file.
"""Construct a new MLLAMDatastore from the configuration file at `config_path`.
A boundary mask is created with `n_boundary_points` boundary points. If
`reuse_existing` is True, the dataset is loaded from a zarr file if it exists
(unless the config has been modified since the zarr was created), otherwise it
is created from the configuration file.
Parameters
----------
Expand Down Expand Up @@ -74,6 +61,11 @@ def __init__(self, config_path, n_boundary_points=30, reuse_existing=True):
self._ds.to_zarr(fp_ds)
self._n_boundary_points = n_boundary_points

print("Training with the following features:")
for category in ["state", "forcing", "static"]:
if len(self.get_vars_names(category)) > 0:
print(f"{category}: {' '.join(self.get_vars_names(category))}")

@property
def root_path(self) -> Path:
"""The root path of the dataset.
Expand Down Expand Up @@ -166,24 +158,11 @@ def get_num_data_vars(self, category: str) -> int:
return len(self.get_vars_names(category))

def get_dataarray(self, category: str, split: str) -> xr.DataArray:
"""Return the
processed data (as a
single `xr.DataArray`)
for the given category
of data and
test/train/val-split
that covers all the
data (in space and
time) of a given
category (state/forcin
g/static). "state" is
the only required
category, for other
categories, the method
will return `None` if
the category is not
found in the
datastore.
"""Return the processed data (as a single `xr.DataArray`) for the given category
of data and test/train/val-split that covers all the data (in space and time) of
a given category (state/forcing/static). "state" is the only required category,
for other categories, the method will return `None` if the category is not found
in the datastore.
The returned dataarray will at minimum have dimensions of `(grid_index,
{category}_feature)` so that any spatial dimensions have been stacked
Expand Down Expand Up @@ -236,23 +215,10 @@ def get_dataarray(self, category: str, split: str) -> xr.DataArray:
return da_category.sel(time=slice(t_start, t_end))

def get_normalization_dataarray(self, category: str) -> xr.Dataset:
"""Return the
normalization
dataarray for the
given category. This
should contain a
`{category}_mean` and
`{category}_std`
variable for each
variable in the
category. For
`category=="state"`,
the dataarray should
also contain a
`state_diff_mean` and
`state_diff_std`
variable for the one-
step differences of
"""Return the normalization dataarray for the given category. This should
contain a `{category}_mean` and `{category}_std` variable for each variable in
the category. For `category=="state"`, the dataarray should also contain a
`state_diff_mean` and `state_diff_std` variable for the one-step differences of
the state variables.
Parameters
Expand Down Expand Up @@ -283,24 +249,11 @@ def get_normalization_dataarray(self, category: str) -> xr.Dataset:

@property
def boundary_mask(self) -> xr.DataArray:
"""Produce a 0/1 mask
for the boundary
points of the dataset,
these will sit at the
edges of the domain
(in x/y extent) and
will be used to mask
out the boundary
points from the loss
function and to
overwrite the boundary
points from the
prediction. For now
this is created when
the mask is requested,
but in the future this
could be saved to the
zarr file.
"""Produce a 0/1 mask for the boundary points of the dataset, these will sit at
the edges of the domain (in x/y extent) and will be used to mask out the
boundary points from the loss function and to overwrite the boundary points from
the prediction. For now this is created when the mask is requested, but in the
future this could be saved to the zarr file.
Returns
-------
Expand Down
51 changes: 11 additions & 40 deletions neural_lam/datastore/multizarr/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,19 +18,10 @@ class MultiZarrDatastore(BaseCartesianDatastore):
DIMS_TO_KEEP = {"time", "grid_index", "variable_name"}

def __init__(self, config_path):
"""Create a multi-zarr
datastore from the
given configuration
file. The
configuration file
should be a YAML file,
the format of which is
should be inferred
from the example
configuration file in
`tests/datastore_examp
les/multizarr/data_con
fig.yml`.
"""Create a multi-zarr datastore from the given configuration file. The
configuration file should be a YAML file, the format of which should be
inferred from the example configuration file in
`tests/datastore_examples/multizarr/data_config.yml`.
Parameters
----------
Expand Down Expand Up @@ -390,33 +381,13 @@ def get_xy(self, category, stacked=True):

@functools.lru_cache()
def get_normalization_dataarray(self, category: str) -> xr.Dataset:
"""Return the
normalization
dataarray for the
given category. This
should contain a
`{category}_mean` and
`{category}_std`
variable for each
variable in the
category. For
`category=="state"`,
the dataarray should
also contain a
`state_diff_mean` and
`state_diff_std`
variable for the one-
step differences of
the state variables.
The return dataarray
should at least have
dimensions of `({categ
ory}_feature)`, but
can also include for
example `grid_index`
(if the normalisation
is done per grid point
for example).
"""Return the normalization dataarray for the given category. This should
contain a `{category}_mean` and `{category}_std` variable for each variable in
the category. For `category=="state"`, the dataarray should also contain a
`state_diff_mean` and `state_diff_std` variable for the one-step differences of
the state variables. The return dataarray should at least have dimensions of
`({category}_feature)`, but can also include for example `grid_index` (if the
normalisation is done per grid point for example).
Parameters
----------
Expand Down
40 changes: 8 additions & 32 deletions neural_lam/datastore/npyfiles/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,21 +281,10 @@ def get_dataarray(self, category: str, split: str) -> DataArray:
def _get_single_timeseries_dataarray(
self, features: List[str], split: str, member: int = None
) -> DataArray:
"""Get the data array
spanning the complete
time series for a
given set of features
and split of data. For
state features the
`member` argument
should be specified to
select the ensemble
member to load. The
data will be loaded
using dask.delayed, so
that the data isn't
actually loaded until
it's needed.
"""Get the data array spanning the complete time series for a given set of
features and split of data. For state features the `member` argument should be
specified to select the ensemble member to load. The data will be loaded using
dask.delayed, so that the data isn't actually loaded until it's needed.
Parameters
----------
Expand Down Expand Up @@ -614,23 +603,10 @@ def boundary_mask(self) -> xr.DataArray:
return da_mask_stacked_xy

def get_normalization_dataarray(self, category: str) -> xr.Dataset:
"""Return the
normalization
dataarray for the
given category. This
should contain a
`{category}_mean` and
`{category}_std`
variable for each
variable in the
category. For
`category=="state"`,
the dataarray should
also contain a
`state_diff_mean` and
`state_diff_std`
variable for the one-
step differences of
"""Return the normalization dataarray for the given category. This should
contain a `{category}_mean` and `{category}_std` variable for each variable in
the category. For `category=="state"`, the dataarray should also contain a
`state_diff_mean` and `state_diff_std` variable for the one-step differences of
the state variables.
Parameters
Expand Down

0 comments on commit 0a79c74

Please sign in to comment.