Common dataset class (#115)
* first split of stats functions

* finishing stats computation isolation

* lint

* create DataAccessor class and implementation

* factoring out attributes from titan and poesy datasets

* separating titan_cli from titan dataset file

* adding poesy cli and separating from main dataset

* lint

* named tensors isolated + rationalizing imports

* lint again

* fixing imports

* reformat

* lint2

* corrections

* poesy ok with hardcoded stats name

* titan small changes

* correction get_dataset_path

* corrections

* factorize sample list

* fix imports

* remove named_tensors from py4cast (now mfai)

* named tensor from mfai

* factorize get_param_tensor

* fix poesy tensor shape

* moving concrete datasets to cli only

* lint

* fix Titan, Dummy and lint

* fix imports and dataset cli

* fix imports

* fix import Titan

* typo failing test models

* create dummy config

* lint

* fix dummy dataset for tests + setup class methods

* fix imports

* fix dir for dummy config

* lint

* fix config file

* fix dummy config file name

* simplify param list with accessor

* removing useless imports

* make dataset name flexible

* lint

* remove registry dependency on train

* Update add_features_contribute.md

* Update add_features_contribute.md

* modify parameter naming, fix bugs in input-output distribution. Working with arpege-lbc

* lint

* Update py4cast/datasets/base.py

Co-authored-by: Lea Berthomier <[email protected]>

* modifying titan lbc to compat with refacto

* add doc to exists function

* lint

* lint and add suggestion

* Update py4cast/datasets/access.py

Co-authored-by: colon3ltocard <[email protected]>

* Update py4cast/datasets/access.py

Co-authored-by: colon3ltocard <[email protected]>

* fix naming of aggregate/stat_name

* remove NotImplementedError from decorated abstract methods

---------

Co-authored-by: Corentin <[email protected]>
Co-authored-by: Corentin <[email protected]>
Co-authored-by: Lea Berthomier <[email protected]>
Co-authored-by: colon3ltocard <[email protected]>
5 people authored Jan 7, 2025
1 parent c13f24c commit 9e29fd2
Showing 20 changed files with 1,733 additions and 3,178 deletions.
2 changes: 0 additions & 2 deletions bin/train.py
@@ -20,7 +20,6 @@
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from py4cast.datasets import registry as dataset_registry
from py4cast.datasets.base import TorchDataloaderSettings
from py4cast.lightning import (
ArLightningHyperParam,
@@ -56,7 +55,6 @@
type=str,
default="titan",
help="Dataset to use",
choices=dataset_registry.keys(),
)
parser.add_argument(
"--dataset_conf",
7 changes: 3 additions & 4 deletions config/datasets/titan_lbc.json
@@ -3,17 +3,17 @@
"train": {
"start": 2021010103,
"end": 2022123123,
"step": 1
"step_duration": 3
},
"valid": {
"start": 2023010103,
"end": 2023123123,
"step": 3
"step_duration": 3
},
"test": {
"start": 2023010103,
"end": 2023123123,
"step": 3
"step_duration": 3
}
},
"grid":{
@@ -24,7 +24,6 @@
"projection_kwargs": {}
},
"settings":{
"step_duration": 1,
"standardize": true,
"file_format": "npy"
},
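For clarity, after this change each period carries its own `step_duration` instead of a single value under `settings`. A sketch of the resulting period entries, with values copied from the diff above (the enclosing key is elided by the hunk boundary):

```json
{
  "train": {"start": 2021010103, "end": 2022123123, "step_duration": 3},
  "valid": {"start": 2023010103, "end": 2023123123, "step_duration": 3},
  "test":  {"start": 2023010103, "end": 2023123123, "step_duration": 3}
}
```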
38 changes: 21 additions & 17 deletions doc/add_features_contribute.md
@@ -87,29 +87,33 @@ We provide an example module [here](../py4cast_plugin_example.py) to help you cr

### Adding a new dataset

1. Your dataset **MUST** inherit from **DatasetABC** and **data.Dataset**, in that order.
All datasets inherit from the PyTorch Dataset class and from a specific DatasetABC class that handles "weather-like" inputs.
The user only needs to subclass the DataAccessor abstract base class, defining the methods that access data on disk.

```python
class TitanDataset(DatasetABC, Dataset):
...
```
You must decide how to define, for your dataset (a purely illustrative sketch follows the list):
* a parameter **WeatherParam**, i.e. a 2D field corresponding to a weather variable (e.g. geopotential at 500 hPa), including its long name, level, and type of level;
* a grid **Grid**, i.e. the way a 2D field is represented on the geoid. Regular latitude-longitude is one example, but it can be something else (Lambert, etc.);
* a **Period**, a contiguous array of dates and hours describing the temporal extent of your dataset;
* a **Timestamp**, i.e. for a single sample, how the dates and hours involved are defined. In an autoregressive framework, one is typically interested in a day D, an hour h, and previous/subsequent timesteps: D:h-dt, D:h, D:h+dt, etc.;
* **Stats**: normalization constants used to pre-process the inputs; these are metadata which can be pre-computed and stored in a `cache_dir`.
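As a purely illustrative sketch — these are **not** the real py4cast classes, and every field name below is an assumption — the concepts can be pictured like this:

```python
from dataclasses import dataclass
from datetime import datetime, timedelta


@dataclass
class WeatherParam:
    """Hypothetical: describes one 2D weather variable."""

    long_name: str   # e.g. "geopotential"
    level: int       # e.g. 500
    level_type: str  # e.g. "isobaricInhPa"


@dataclass
class Period:
    """Hypothetical: a contiguous range of dates and hours."""

    start: datetime
    end: datetime
    step_duration: timedelta  # spacing between successive dates


@dataclass
class Timestamp:
    """Hypothetical: the dates and hours involved in one sample."""

    reference_time: datetime    # the day D at hour h
    offsets: list[timedelta]    # e.g. [-dt, timedelta(0), +dt]
```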

2. Your dataset **MUST** implement
* ALL the abstract properties from **DatasetABC**: `dataset_info`, `meshgrid`, `geopotential_info` and `cache_dir`.
* ALL the abstract methods from **DatasetABC**: `torch_dataloader` and `from_json`.
* 2 methods from **data.Dataset**: `__len__` and `__getitem__`.
This is done with the following steps:

3. Your `__getitem__` method MUST return an Item object.

```python
def __getitem__(self, index: int) -> Item:
    ...
```

1. Your accessor **MUST** inherit from the **DataAccessor** abstract base class.

```python
class TitanAccessor(DataAccessor):
    ...
```

4. It is MANDATORY that your `__getitem__` method returns [Item](../py4cast/datasets/base.py#L288) instances containing NamedTensors with precise feature and dimension names. By convention we use these names for tensor dimensions: **("batch", "timestep", "lat", "lon", "features")**.

5. It is HIGHLY RECOMMENDED that your dataset implements a `prepare` method to easily compute and save all the statistics needed by your dataset.

2. Your DataAccessor **MUST** implement all the abstract methods from **DataAccessor** (a minimal skeleton is sketched after this list):
* `load_grid_info`, `load_param_info`, `get_weight_per_level`: decide how metadata is loaded;
* `get_dataset_path`, `cache_dir`, `get_filepath`: decide how directories and files are laid out in your architecture;
* `load_data_from_disk`, `exists`, `valid_timestamps`: used to verify the validity of a sample and to load individual data chunks from disk.
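Such a skeleton might look as follows. This is a sketch only: the method signatures are assumptions made for illustration, and the authoritative ones are defined by the abstract base class in `py4cast/datasets/access.py`.

```python
from py4cast.datasets.access import DataAccessor


class MyAccessor(DataAccessor):
    """Hypothetical accessor; all signatures below are illustrative."""

    # Metadata loading
    def load_grid_info(self, name):
        ...  # build the Grid description for this dataset

    def load_param_info(self, name):
        ...  # build the WeatherParam descriptions

    def get_weight_per_level(self, level, level_type):
        ...  # weight associated with a given vertical level

    # Directory and file layout
    def get_dataset_path(self, name, grid):
        ...  # root directory of the dataset on disk

    def cache_dir(self, name, grid):
        ...  # where pre-computed stats are stored

    def get_filepath(self, name, param, timestamp):
        ...  # path of the file holding one data chunk

    # Sample validity and data loading
    def exists(self, name, param, timestamp):
        ...  # True if the corresponding data chunk is present on disk

    def valid_timestamps(self, name, period):
        ...  # timestamps of the period that yield valid samples

    def load_data_from_disk(self, name, param, timestamp):
        ...  # load one data chunk from disk
```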
3. This code must be included in a submodule under the `py4cast.datasets` module (see the layout sketch below), with:
* the `__init__.py` file containing the definition of the DataAccessor class;
* additional files, such as `settings.py` (defining constants such as directories) or `metadata.yaml` (providing general information on the dataset). While this is up to the user, we recommend following the examples of `titan` (a reanalysis dataset) or `poesy` (an ensemble reforecast dataset).
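For illustration, such a submodule could be laid out like this (`my_dataset` is a placeholder name):

```
py4cast/datasets/
└── my_dataset/
    ├── __init__.py    # defines MyAccessor(DataAccessor)
    ├── settings.py    # constants such as data directories
    └── metadata.yaml  # general information on the dataset
```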

4. You must modify the `registry` object in the `py4cast.datasets.__init__.py` module to include your custom DataAccessor (see the sketch below). After that, your DataAccessor will be used whenever the requested dataset name contains the registered accessor name as a substring.
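For instance, mirroring the pattern used for the existing accessors in the `py4cast/datasets/__init__.py` diff below (`my_dataset` and its config file are placeholders):

```python
try:
    from .my_dataset import MyAccessor

    registry["my_dataset"] = (
        MyAccessor,
        DEFAULT_CONFIG_DIR / "datasets" / "my_dataset.json",
    )
except ImportError:
    warnings.warn(f"Could not import MyAccessor. {traceback.format_exc()}")
```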

### Adding Training Plots

73 changes: 31 additions & 42 deletions py4cast/datasets/__init__.py
@@ -3,6 +3,8 @@
from pathlib import Path
from typing import Dict, Tuple, Union

from py4cast.settings import DEFAULT_CONFIG_DIR

from .base import DatasetABC # noqa: F401

registry = {}
@@ -14,59 +16,37 @@
# break the code
# NEW DATASETS MUST BE REGISTERED HERE


default_config_root = Path(__file__).parents[2] / "config/datasets/"


try:
from .smeagol import SmeagolDataset

registry["smeagol"] = (SmeagolDataset, default_config_root / "smeagol.json")
except ImportError:
warnings.warn(f"Could not import SmeagolDataset. {traceback.format_exc()}")

try:
from .smeagol import InferSmeagolDataset
from .titan import TitanAccessor

registry["smeagol_infer"] = (
InferSmeagolDataset,
default_config_root / "smeagol.json",
registry["titan"] = (
TitanAccessor,
DEFAULT_CONFIG_DIR / "datasets" / "titan_refacto.json",
)
except ImportError:
warnings.warn(f"Could not import SmeagolDataset. {traceback.format_exc()}")


try:
from .titan import TitanDataset

registry["titan"] = (TitanDataset, default_config_root / "titan_full.json")

except (ImportError, FileNotFoundError, ModuleNotFoundError):
warnings.warn(f"Could not import TitanDataset. {traceback.format_exc()}")

try:
from .poesy import PoesyDataset

registry["poesy"] = (PoesyDataset, default_config_root / "poesy_refacto.json")
except ImportError:
warnings.warn(f"Could not import PoesyDataset. {traceback.format_exc()}")
warnings.warn(f"Could not import TitanAccessor. {traceback.format_exc()}")

try:
from .poesy import InferPoesyDataset
from .poesy import PoesyAccessor

registry["poesy_infer"] = (
InferPoesyDataset,
default_config_root / "poesy_infer.json",
registry["poesy"] = (
PoesyAccessor,
DEFAULT_CONFIG_DIR / "datasets" / "poesy_refacto.json",
)

except ImportError:
warnings.warn(f"Could not import InferPoesyDataset. {traceback.format_exc()}")
warnings.warn(f"Could not import PoesyAccessor. {traceback.format_exc()}")

try:
from .dummy import DummyDataset
from .dummy import DummyAccessor

registry["dummy"] = (DummyDataset, "")
registry["dummy"] = (
DummyAccessor,
DEFAULT_CONFIG_DIR / "datasets" / "dummy_config.json",
)
except ImportError:
warnings.warn(f"Could not import DummyDataset. {traceback.format_exc()}")
warnings.warn(f"Could not import DummyAccessor. {traceback.format_exc()}")


def get_datasets(
@@ -83,16 +63,25 @@ def get_datasets(
Returns 3 instances of the dataset: train, val, test
"""

# checks if name has a registry key as component (substring)
registered_name = ""
for k in registry:
if k in name.lower():
registered_name = k
break
try:
dataset_kls, default_config = registry[name]
accessor_kls, default_config = registry[registered_name]
except KeyError as ke:
raise ValueError(
f"Dataset {name} not found in registry, available datasets are :{registry.keys()}"
f"Dataset {name} doesn't match a registry substring, available datasets are :{registry.keys()}"
) from ke

config_file = default_config if config_file is None else Path(config_file)

return dataset_kls.from_json(
return DatasetABC.from_json(
accessor_kls,
name,
config_file,
num_input_steps,
num_pred_steps_train,
