1 feature improved configuration and data structures #79

Draft: wants to merge 89 commits into base: develop

Changes from all commits (89 commits)
639025b
feat: dataclasses for training config
theissenhelen Aug 16, 2024
ffbdff4
test: add test for dataclasses
theissenhelen Aug 16, 2024
ad9c178
refactor: unused config values
theissenhelen Sep 10, 2024
bae570c
feat: structured TrainingConfig
theissenhelen Sep 23, 2024
76cd43c
fix: training config not in configstore
theissenhelen Sep 23, 2024
194bac8
feat: HardwareConfig
theissenhelen Sep 23, 2024
3cb472a
fix: missing base config attributes
theissenhelen Sep 24, 2024
cbe84bd
fix: variable names (temporary fix only)
theissenhelen Sep 24, 2024
01d7bc0
feat: add data config schema
theissenhelen Sep 24, 2024
7d5a047
feat: add structured config for gnn
theissenhelen Sep 25, 2024
83fc4d2
feat: structured configs for transformer and graphtransformer
theissenhelen Sep 25, 2024
a623237
feat: extended config schema for model architectures
theissenhelen Sep 26, 2024
6f23137
feat: add diagnostics structured config
theissenhelen Oct 9, 2024
3ffaca2
feat: translate hardware config to pydantic
theissenhelen Oct 24, 2024
de33e0c
feat: translate data to pydantic
theissenhelen Oct 24, 2024
b2f192e
feat: translate training and diagostic to pydantic
theissenhelen Oct 24, 2024
9a05ef3
fix: hydra instantiation
theissenhelen Oct 25, 2024
c98a90b
feat: translate gnn config to pydantic
theissenhelen Oct 30, 2024
25d2c2f
fix: config setup working
theissenhelen Oct 30, 2024
54c731d
fix: type hints
theissenhelen Oct 30, 2024
f953c31
refactor: remove model component
theissenhelen Oct 30, 2024
b352fa1
feat: translate transformer config to pydantic
theissenhelen Oct 31, 2024
caf7b1b
feat: translate GraphTransformerConfig to pydantic
theissenhelen Oct 31, 2024
0c4568f
feat: add target validator
theissenhelen Oct 31, 2024
c44d231
chore: refactor
theissenhelen Oct 31, 2024
674e2d7
feat: add defaults
theissenhelen Oct 31, 2024
4d3ef0c
feat: add basic graph schemas
theissenhelen Nov 8, 2024
0e319a0
refactor: rename IcosahedralNodeSchema
theissenhelen Nov 8, 2024
076617b
feat: add schema for stretched grid
theissenhelen Nov 8, 2024
d3d3674
fix: import error
theissenhelen Nov 8, 2024
561ec39
feat: adjust to new training schema
theissenhelen Nov 11, 2024
98a16c6
refactor: rename
theissenhelen Nov 11, 2024
c704906
fix: spelling
theissenhelen Nov 11, 2024
2e894c1
refactor: adjust callbacks to pydantic
theissenhelen Nov 11, 2024
0ef7471
chore: put graph schema back in
theissenhelen Nov 11, 2024
f449c36
chore(deps): add pydantic
theissenhelen Nov 11, 2024
61fd8c7
fix: import
theissenhelen Nov 11, 2024
5c3c066
fix: missing imports
theissenhelen Nov 12, 2024
e6c3fdd
fix AnyUrl not supported by omegaconf
theissenhelen Nov 12, 2024
2c7b3a9
feat: merge loss and validation metrics schema
theissenhelen Nov 19, 2024
bde85aa
refactor: pressure level scaling config
theissenhelen Nov 20, 2024
af855f5
feat: add dataloader schema
theissenhelen Nov 21, 2024
b9a1690
feat: adjust datamodule to use dataloader schema
theissenhelen Nov 21, 2024
c9ee0af
feat: make Frequency model compatible
theissenhelen Nov 22, 2024
8b16f64
feat: add benchmarkprofiler schema
theissenhelen Nov 22, 2024
872b3eb
feat: config validate command
theissenhelen Nov 22, 2024
c50aa3a
refactor: replace with enum
theissenhelen Nov 25, 2024
5cdb0be
docs: add description to dataloader schema
Dec 2, 2024
a7d95f3
docs: add description to data schema
Dec 4, 2024
3df5eb2
doc: add docstrings
theissenhelen Dec 4, 2024
895a3c7
refactor: move schemas folder up
theissenhelen Dec 4, 2024
4412a19
chore: add autodoc_pydantic
theissenhelen Dec 4, 2024
de4f32f
doc: dosctrings to hardware schemas
theissenhelen Dec 4, 2024
f7fb892
fix: serialising Enums
theissenhelen Dec 4, 2024
9e5ef60
feat: add http_max_retries to config and basemodel
theissenhelen Dec 5, 2024
86a8a9a
feat: set default value of read_group_size
theissenhelen Dec 5, 2024
a9b64e7
feat: accelerator check moves to pydantic hardware schema
theissenhelen Dec 5, 2024
edaa60e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Dec 5, 2024
091ca0b
refactor: replace target validation with Literal
theissenhelen Dec 5, 2024
f445c1b
feat: replace validators with enums
theissenhelen Dec 5, 2024
5ade973
feat: add model_validator to adjust the learning rate to hardware set…
theissenhelen Dec 5, 2024
844cd24
fix: missing configs
theissenhelen Dec 5, 2024
ee3f362
docs: adjust description format for data and dataloader schema
chebertpinard Dec 9, 2024
a9f0a3c
docs: add description to diagnostics schema
chebertpinard Dec 9, 2024
b6cbab2
Merge branch 'develop' into 1-feature-improved-configuration-and-data…
theissenhelen Dec 10, 2024
e3f2ecf
fix: adjust to changes from develop
theissenhelen Dec 12, 2024
6bd4fac
Merge branch '1-feature-improved-configuration-and-data-structures' o…
theissenhelen Dec 12, 2024
937536f
docs: allow Any plot callbacks in diagnostics
chebertpinard Dec 19, 2024
d7a8d93
docs: add description to models schema
chebertpinard Dec 19, 2024
e46a015
Merge commit 'd7a8d93d0ff49ab258c47cc0252f181a24c79ceb' into 1-featur…
theissenhelen Jan 6, 2025
16310f9
fix: missing grid indices
theissenhelen Jan 9, 2025
8e780dd
fix: missing output mask
theissenhelen Jan 9, 2025
2571f59
fix: now running
theissenhelen Jan 9, 2025
cd89e42
feat: add flexible validators for model and model schema
theissenhelen Jan 9, 2025
3e0be01
chore: pre-commit
theissenhelen Jan 9, 2025
80a5d4e
feat: warning instead of validation error if model not defined in anemoi
theissenhelen Jan 10, 2025
aab3e0b
fix: needed to access graph attribute in config
theissenhelen Jan 10, 2025
5490158
refactor: add type to dataset variable in graph TextNodes
chebertpinard Jan 9, 2025
65839ad
docs: Add descriptions to grid indices.
chebertpinard Jan 10, 2025
5fa2d61
docs: complete description to hardware config schema.
chebertpinard Jan 10, 2025
fe4dae9
docs: complete description to training config schema.
chebertpinard Jan 10, 2025
1297328
fix: add missing graph builder schema and description.
chebertpinard Jan 10, 2025
fc25fa2
Merge branch '1-feature-improved-configuration-and-data-structures' o…
theissenhelen Jan 13, 2025
cf3475f
refactor: streamline code
theissenhelen Jan 13, 2025
26a4340
fix: minor bugs
theissenhelen Jan 13, 2025
2524d78
feat: add pytorch activations
theissenhelen Jan 13, 2025
8425db9
refactor: consistent naming
theissenhelen Jan 13, 2025
6e0913f
fix: naming and plot callback initialisation
theissenhelen Jan 13, 2025
ecb6233
refactor: naming
theissenhelen Jan 13, 2025
6 changes: 3 additions & 3 deletions graphs/src/anemoi/graphs/nodes/builders/from_file.py
@@ -68,15 +68,15 @@ class TextNodes(BaseNodeBuilder):

Attributes
----------
dataset : str | DictConfig
The path to txt file containing the coordinates of the nodes.
dataset : str | Path
The path including filename to txt file containing the coordinates of the nodes.
idx_lon : int
The index of the longitude in the dataset.
idx_lat : int
The index of the latitude in the dataset.
"""

def __init__(self, dataset, name: str, idx_lon: int = 0, idx_lat: int = 1) -> None:
def __init__(self, dataset: str | Path, name: str, idx_lon: int = 0, idx_lat: int = 1) -> None:
LOGGER.info("Reading the dataset from %s.", dataset)
self.dataset = np.loadtxt(dataset)
self.idx_lon = idx_lon
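The hunk above changes TextNodes to take a plain path rather than a DictConfig. A minimal usage sketch under that assumption (the coordinates file name is hypothetical):

from pathlib import Path

from anemoi.graphs.nodes.builders.from_file import TextNodes

# "coords.txt" is a hypothetical whitespace-separated file with longitude in
# column 0 and latitude in column 1, matching the idx_lon/idx_lat defaults.
nodes_builder = TextNodes(dataset=Path("coords.txt"), name="data")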
@@ -49,7 +49,7 @@ def __init__(
Graph definition
"""
super().__init__()

model_config = DotDict(model_config)
self._graph_data = graph_data
self._graph_name_data = model_config.graph.data
self._graph_name_hidden = model_config.graph.hidden
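The added DotDict wrapper keeps attribute-style access working inside the model interface even when the incoming config is a pydantic schema or a plain mapping. A rough sketch of that behaviour, assuming DotDict lives in anemoi.utils.config:

from anemoi.utils.config import DotDict  # assumed import location

model_config = DotDict({"graph": {"data": "data", "hidden": "hidden"}})
assert model_config.graph.data == "data"  # nested keys become attributes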
1 change: 1 addition & 0 deletions training/docs/conf.py
@@ -70,6 +70,7 @@
"sphinx.ext.napoleon",
"sphinxarg.ext",
"sphinx.ext.autosectionlabel",
"sphinxcontrib.autodoc_pydantic",
]

# Add any paths that contain templates here, relative to this directory.
2 changes: 2 additions & 0 deletions training/pyproject.toml
@@ -50,6 +50,7 @@ dependencies = [
"matplotlib>=3.7.1",
"mlflow>=2.11.1",
"numpy<2", # Pinned until we can confirm it works with anemoi graphs
"pydantic>=2.9",
"pynvml>=11.5",
"pyshtools>=4.10.4",
"pytorch-lightning>=2.1",
@@ -69,6 +70,7 @@ optional-dependencies.dev = [
]

optional-dependencies.docs = [
"autodoc-pydantic",
"nbsphinx",
"pandoc",
"sphinx",
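The new pydantic>=2.9 runtime dependency is what the schema modules in this PR are built on. A minimal illustration of the pattern (field names follow the hardware settings referenced elsewhere in the diff, but this is not the project's actual HardwareSchema):

from pydantic import BaseModel

class HardwareExample(BaseModel):  # illustrative only, not the real schema
    num_gpus_per_node: int = 1
    num_nodes: int = 1
    num_gpus_per_model: int = 1

# Mistyped values now fail at validation time rather than deep inside a training run:
HardwareExample(num_gpus_per_node="four")  # raises pydantic.ValidationError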
33 changes: 29 additions & 4 deletions training/src/anemoi/training/commands/config.py
@@ -16,7 +16,11 @@
from pathlib import Path
from typing import TYPE_CHECKING

from hydra import compose
from hydra import initialize

from anemoi.training.commands import Command
from anemoi.training.schemas.base_schema import BaseSchema

if TYPE_CHECKING:
import argparse
@@ -48,24 +52,39 @@ def add_arguments(command_parser: argparse.ArgumentParser) -> None:
)
anemoi_training_home.add_argument("--overwrite", "-f", action="store_true")

help_msg = "Validate the Anemoi training configs."
validate = subparsers.add_parser("validate", help=help_msg, description=help_msg)

validate.add_argument("--name", help="Name of the primary config file")
validate.add_argument("--overwrite", "-f", action="store_true")

def run(self, args: argparse.Namespace) -> None:
LOGGER.info(
"Generating configs, please wait.",
)

self.overwrite = args.overwrite
if args.subcommand == "generate":

LOGGER.info(
"Generating configs, please wait.",
)
self.traverse_config(args.output)

LOGGER.info("Inference checkpoint saved to %s", args.output)
return

if args.subcommand == "training-home":
anemoi_home = Path.home() / ".config" / "anemoi" / "training" / "config"
LOGGER.info(
"Generating configs, please wait.",
)
self.traverse_config(anemoi_home)
LOGGER.info("Inference checkpoint saved to %s", anemoi_home)
return

if args.subcommand == "validate":
LOGGER.info("Validating configs.")
self.validate_config(args.name)
LOGGER.info("Config files validated.")
return

def traverse_config(self, destination_dir: Path | str) -> None:
"""Writes the given configuration data to the specified file path."""
config_package = "anemoi.training.config"
@@ -97,5 +116,11 @@ def copy_file(item: Path, file_path: Path) -> None:
except Exception:
LOGGER.exception("Failed to copy %s", item.name)

def validate_config(self, name: Path | str) -> None:
"""Validates the configuration files in the given directory."""
with initialize(version_base=None, config_path=""):
cfg = compose(config_name=name)
cfg = BaseSchema(**cfg)


command = ConfigGenerator
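A sketch of what the new validate subcommand amounts to, assuming a primary config composed from the current search path ("config" below is a placeholder name); BaseSchema raises a pydantic ValidationError for any field that does not match the schema:

from hydra import compose, initialize
from pydantic import ValidationError

from anemoi.training.schemas.base_schema import BaseSchema

with initialize(version_base=None, config_path=""):
    composed = compose(config_name="config")  # placeholder primary config name
    try:
        BaseSchema(**composed)
        print("Config files validated.")
    except ValidationError as err:
        print(err)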
4 changes: 2 additions & 2 deletions training/src/anemoi/training/config/data/zarr.yaml
@@ -74,12 +74,12 @@ processors:
# config: ${data.imputer}
normalizer:
_target_: anemoi.models.preprocessing.normalizer.InputNormalizer
_convert_: all
# _convert_: all # Is it still used ???
config: ${data.normalizer}
# remapper:
# _target_: anemoi.models.preprocessing.remapper.Remapper
# _convert_: all
# config: ${data.remapper}

# Values set in the code
# Values set in the code
num_features: null # number of features in the forecast state
@@ -16,12 +16,10 @@ num_workers:
training: 8
validation: 8
test: 8
predict: 8
batch_size:
training: 2
validation: 4
test: 4
predict: 4

# ============
# Default effective batch_size for training is 16
@@ -38,7 +36,6 @@ limit_batches:
training: null
validation: null
test: 20
predict: 20

# set a custom mask for grid points.
# Useful for LAM (dropping unconnected nodes from forcing dataset)
@@ -52,6 +52,7 @@ log:
on_resume_create_child: True
expand_hyperparams: # Which keys in hyperparams to expand
- config
http_max_retries: 35
interval: 100

enable_progress_bar: True
1 change: 1 addition & 0 deletions training/src/anemoi/training/config/training/default.yaml
@@ -3,6 +3,7 @@ run_id: null
fork_run_id: null
load_weights_only: null # only load model weights, do not restore optimiser states etc.
transfer_learning: null # activate to perform transfer learning
load_weights_only: False # only load model weights, do not restore optimiser states etc.

# run in deterministic mode ; slows down
deterministic: False
37 changes: 21 additions & 16 deletions training/src/anemoi/training/data/datamodule.py
@@ -16,14 +16,14 @@

import pytorch_lightning as pl
from hydra.utils import instantiate
from omegaconf import DictConfig
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from anemoi.datasets.data import open_dataset
from anemoi.models.data_indices.collection import IndexCollection
from anemoi.training.data.dataset import NativeGridDataset
from anemoi.training.data.dataset import worker_init_func
from anemoi.training.schemas.base_schema import BaseSchema
from anemoi.training.schemas.base_schema import convert_to_omegaconf
from anemoi.utils.dates import frequency_to_seconds

LOGGER = logging.getLogger(__name__)
@@ -37,12 +37,12 @@
class AnemoiDatasetsDataModule(pl.LightningDataModule):
"""Anemoi Datasets data module for PyTorch Lightning."""

def __init__(self, config: DictConfig, graph_data: HeteroData) -> None:
def __init__(self, config: BaseSchema, graph_data: HeteroData) -> None:
"""Initialize Anemoi Datasets data module.

Parameters
----------
config : DictConfig
config : BaseSchema
Job configuration

"""
@@ -66,7 +66,7 @@ def __init__(self, config: DictConfig, graph_data: HeteroData) -> None:
)
self.config.dataloader.training.end = self.config.dataloader.validation.start - 1

if not self.config.dataloader.get("pin_memory", True):
if not self.config.dataloader.pin_memory:
LOGGER.info("Data loader memory pinning disabled.")

@cached_property
@@ -83,12 +83,16 @@ def supporting_arrays(self) -> dict:

@cached_property
def data_indices(self) -> IndexCollection:
return IndexCollection(self.config, self.ds_train.name_to_index)
return IndexCollection(convert_to_omegaconf(self.config), self.ds_train.name_to_index)

@cached_property
def grid_indices(self) -> type[BaseGridIndices]:
reader_group_size = self.config.dataloader.get("read_group_size", self.config.hardware.num_gpus_per_model)
grid_indices = instantiate(self.config.dataloader.grid_indices, reader_group_size=reader_group_size)
reader_group_size = self.config.dataloader.read_group_size

grid_indices = instantiate(
self.config.dataloader.grid_indices.model_dump(by_alias=True),
reader_group_size=reader_group_size,
)
grid_indices.setup(self.graph_data)
return grid_indices

@@ -123,13 +127,14 @@ def timeincrement(self) -> int:
@cached_property
def ds_train(self) -> NativeGridDataset:
return self._get_dataset(
open_dataset(OmegaConf.to_container(self.config.dataloader.training, resolve=True)),
open_dataset(self.config.dataloader.training.model_dump()),
label="train",
)

@cached_property
def ds_valid(self) -> NativeGridDataset:
r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1))
r = self.rollout
r = max(r, self.config.dataloader.validation_rollout)

if not self.config.dataloader.training.end < self.config.dataloader.validation.start:
LOGGER.warning(
@@ -138,7 +143,7 @@ def ds_valid(self) -> NativeGridDataset:
self.config.dataloader.validation.start,
)
return self._get_dataset(
open_dataset(OmegaConf.to_container(self.config.dataloader.validation, resolve=True)),
open_dataset(self.config.dataloader.validation.model_dump()),
shuffle=False,
rollout=r,
label="validation",
@@ -155,7 +160,7 @@ def ds_test(self) -> NativeGridDataset:
f"test start date {self.config.dataloader.test.start}"
)
return self._get_dataset(
open_dataset(OmegaConf.to_container(self.config.dataloader.test, resolve=True)),
open_dataset(self.config.dataloader.test.model_dump()),
shuffle=False,
label="test",
)
@@ -172,7 +177,7 @@ def _get_dataset(

# Compute effective batch size
effective_bs = (
self.config.dataloader.batch_size["training"]
self.config.dataloader.batch_size.training
* self.config.hardware.num_gpus_per_node
* self.config.hardware.num_nodes
// self.config.hardware.num_gpus_per_model
@@ -193,12 +198,12 @@ def _get_dataloader(self, ds: NativeGridDataset, stage: str) -> DataLoader:
assert stage in {"training", "validation", "test"}
return DataLoader(
ds,
batch_size=self.config.dataloader.batch_size[stage],
batch_size=self.config.dataloader.batch_size.model_dump()[stage],
# number of worker processes
num_workers=self.config.dataloader.num_workers[stage],
num_workers=self.config.dataloader.num_workers.model_dump()[stage],
# use of pinned memory can speed up CPU-to-GPU data transfers
# see https://pytorch.org/docs/stable/notes/cuda.html#cuda-memory-pinning
pin_memory=self.config.dataloader.get("pin_memory", True),
pin_memory=self.config.dataloader.pin_memory,
# worker initializer
worker_init_fn=worker_init_func,
# prefetch batches
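The .get(...) fallbacks disappear in the datamodule above because defaults now live on the schema itself. A small sketch of that idea (field names mirror the diff, default values are assumptions, and this is not the real dataloader schema):

from pydantic import BaseModel

class DataloaderExample(BaseModel):  # illustrative stand-in for the dataloader schema
    pin_memory: bool = True          # was config.dataloader.get("pin_memory", True)
    validation_rollout: int = 1      # was config.dataloader.get("validation_rollout", 1)
    read_group_size: int = 1         # previously fell back to hardware.num_gpus_per_model

cfg = DataloaderExample()            # omitted keys pick up the declared defaults
assert cfg.pin_memory is True
print(cfg.model_dump())              # plain dict, e.g. for DataLoader or open_dataset kwargs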
32 changes: 16 additions & 16 deletions training/src/anemoi/training/diagnostics/callbacks/__init__.py
@@ -18,6 +18,7 @@

from hydra.utils import instantiate
from omegaconf import DictConfig
from pydantic import BaseModel

from anemoi.training.diagnostics.callbacks.checkpoint import AnemoiCheckpoint
from anemoi.training.diagnostics.callbacks.optimiser import LearningRateMonitor
@@ -28,21 +29,23 @@
if TYPE_CHECKING:
from pytorch_lightning.callbacks import Callback

from anemoi.training.schemas.base_schema import BaseSchema

LOGGER = logging.getLogger(__name__)


def nestedget(conf: DictConfig, key: str, default: Any) -> Any:
def nestedget(config: DictConfig, key: str, default: Any) -> Any:
"""Get a nested key from a DictConfig object.

E.g.
>>> nestedget(config, "diagnostics.log.wandb.enabled", False)
"""
keys = key.split(".")
for k in keys:
conf = conf.get(k, default)
if not isinstance(conf, (dict, DictConfig)):
config = getattr(config, k, default)
if not isinstance(config, (BaseModel, dict, DictConfig)):
break
return conf
return config


# Callbacks to add according to flags in the config
@@ -57,9 +60,9 @@ def nestedget(conf: DictConfig, key: str, default: Any) -> Any:
]


def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint]:
def _get_checkpoint_callback(config: BaseSchema) -> list[AnemoiCheckpoint]:
"""Get checkpointing callbacks."""
if not config.diagnostics.get("enable_checkpointing", True):
if not config.diagnostics.enable_checkpointing:
return []

checkpoint_settings = {
@@ -77,11 +80,11 @@ def _get_checkpoint_callback(config: DictConfig) -> list[AnemoiCheckpoint]:
ckpt_frequency_save_dict = {}

for key, frequency_dict in config.diagnostics.checkpoint.items():
frequency = frequency_dict["save_frequency"]
n_saved = frequency_dict["num_models_saved"]
if key == "every_n_minutes" and frequency_dict["save_frequency"] is not None:
frequency = frequency_dict.save_frequency
n_saved = frequency_dict.num_models_saved
if key == "every_n_minutes" and frequency_dict.save_frequency is not None:
target = "train_time_interval"
frequency = timedelta(minutes=frequency_dict["save_frequency"])
frequency = timedelta(minutes=frequency_dict.save_frequency)
else:
target = key
ckpt_frequency_save_dict[target] = (
@@ -143,7 +146,7 @@ def check_key(config: dict, key: str | Iterable[str] | Callable[[DictConfig], bo
return callbacks


def get_callbacks(config: DictConfig) -> list[Callback]:
def get_callbacks(config: BaseSchema) -> list[Callback]:
"""Setup callbacks for PyTorch Lightning trainer.

Set `config.diagnostics.callbacks` to a list of callback configurations
@@ -183,14 +186,11 @@ def get_callbacks(config: BaseSchema) -> list[Callback]:
trainer_callbacks.extend(_get_checkpoint_callback(config))

# Base callbacks
trainer_callbacks.extend(
instantiate(callback, config) for callback in config.diagnostics.get("callbacks", None) or []
)
trainer_callbacks.extend(instantiate(callback, config) for callback in config.diagnostics.callbacks)

# Plotting callbacks

trainer_callbacks.extend(
instantiate(callback, config) for callback in config.diagnostics.plot.get("callbacks", None) or []
instantiate(callback.model_dump(by_alias=True), config) for callback in config.diagnostics.plot.callbacks
)

# Extend with config enabled callbacks
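The by_alias=True dump above matters because pydantic reserves field names with leading underscores, so "_target_" has to be carried as an alias and restored when the schema is dumped for hydra. A rough sketch of the pattern (the alias setup and callback path are illustrative, not the project's actual schema):

from pydantic import BaseModel, Field

class PlotCallbackExample(BaseModel):        # illustrative schema, not the project's
    target_: str = Field(alias="_target_")   # "_target_" cannot be a field name, so it is aliased

cb_schema = PlotCallbackExample(_target_="my_project.callbacks.MyPlotCallback")  # hypothetical target
cb_dict = cb_schema.model_dump(by_alias=True)  # {"_target_": "my_project.callbacks.MyPlotCallback"}
# hydra.utils.instantiate(cb_dict, config) would then build the callback, as in get_callbacks above.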