Skip to content

Commit

Permalink
Replace trainer function with Trainer class (#185)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: frostedoyster <[email protected]>
Co-authored-by: Filippo Bigi <[email protected]>
Co-authored-by: Arslan Mazitov <[email protected]>
  • Loading branch information
4 people committed May 29, 2024
1 parent 3e8bc4e commit 19451a6
Show file tree
Hide file tree
Showing 79 changed files with 2,445 additions and 2,225 deletions.
6 changes: 0 additions & 6 deletions docs/src/dev-docs/adding-models.rst

This file was deleted.

2 changes: 1 addition & 1 deletion docs/src/dev-docs/architecture-life-cycle.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ repository. To qualify as an experimental architecture, certain criteria must be
a public git repository or another public URL with a repository is acceptable.

For detailed instructions on adding a new architecture, refer to
:ref:`adding-new-models`.
:ref:`adding-new-architecture`.

Stable Architectures
--------------------
Expand Down
2 changes: 1 addition & 1 deletion docs/src/dev-docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ module.
.. toctree::
:maxdepth: 1

adding-models
architecture-life-cycle
new-architecture
cli/index
utils/index
144 changes: 144 additions & 0 deletions docs/src/dev-docs/new-architecture.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
.. _adding-new-architecture:

Adding a new architecture
=========================

To work with` metatensor-models` any architecture has to follow the same public API to
be called correctly within the :py:func:`metatensor.models.cli.train` function to
process the user's options. In brief, the core of the ``train`` function looks similar
to these lines

.. code-block:: python
from architecture import __model__ as Model
from architecture import __trainer__ as Trainer
hypers = {}
dataset_info = DatasetInfo()
if "continue_from":
model = Model.load_checkpoint("path")
model = model.restart(dataset_info)
else:
model = Model(hypers["architecture"], dataset_info)
trainer = Trainer(hypers["training"])
trainer.train(
model=model,
devices=[],
train_datasets=[],
validation_datasets=[],
checkpoint_dir="path",
)
model.save_checkpoint("final.ckpt")
mts_atomistic_model = model.export()
mts_atomistic_model.export("path", collect_extensions="extensions-dir/")
In order to follow this, a new architectures has two define two classes

- a ``Model`` class, defining the core of the architecture. This class must implement
the interface documented below in :py:class:`ModelInterface`
- a ``Trainer`` class, used to train an architecture and produce a model that can be
evaluated and exported. This class must implement the interface documented below in
:py:class:`TrainerInterface`.

The ``ModelInterface`` is the main model class and must implement a
``save_checkpoint()``, ``load_checkpoint()`` as well as a ``restart()`` and
``export()`` method.

.. code-block:: python
class ModelInterface:
__supported_devices__ = ["cuda", "cpu"]
__supported_dtypes__ = [torch.float64, torch.float32]
def __init__(self, model_hypers, dataset_info: DatasetInfo):
self.hypers = model_hypers
self.dataset_info = dataset_info
def save_checkpoint(self, path: Union[str, Path]):
pass
@classmethod
def load_checkpoint(cls, path: Union[str, Path]) -> "ModelInterface":
pass
def restart(cls, dataset_info: DatasetInfo) -> "ModelInterface":
"""Restart training.
This function is called whenever training restarts, with the same or a
different dataset.
It enables transfer learning (changing the targets), and fine tuning (same
targets, different dataset)
"""
pass
def export(self) -> MetatensorAtomisticModel:
pass
Note that the ``ModelInterface`` does not necessary inherit from
:py:class:`torch.nn.Module` since training can be performed in any way.
``__supported_devices__`` and ``__supported_dtypes__`` can be defined to set the
capabilities of the model. These two lists should be sorted in order of preference since
`metatensor-models` will use these to determine, based on the user request and
machines's availability, the optimal `dtype` and `device` for training.

The ``export()`` method is required to transform a trained model into a standalone file
to be used in combination with molecular dynamic engines to run simulations. We provide
a helper function :py:func:`metatensor.models.utils.export.export` to export a torch
model to an :py:class:`MetatensorAtomisticModel
<metatensor.torch.atomistic.MetatensorAtomisticModel>`.

The ``TrainerInterface`` class should have the following signature with a required
methods for ``train()``.

.. code-block:: python
class TrainerInterface:
def __init__(self, train_hypers):
self.hypers = train_hypers
def train(
self,
model: ModelInterface,
devices: List[torch.device],
train_datasets: List[Union[Dataset, torch.utils.data.Subset]],
validation_datasets: List[Union[Dataset, torch.utils.data.Subset]],
checkpoint_dir: str,
): ...
The names of the ``ModelInterface`` and the ``TrainerInterface`` are free to choose but
should be linked to constants in the ``__init__.py`` of each architecture. On top of
these two constants the ``__init__.py`` must contain constants for the original
`__authors__` and current `__maintainers__` of the architecture.

.. code-block:: python
from .model import CustomSOTAModel
from .trainer import Trainer
__model__ = CustomSOTAModel
__trainer__ = Trainer
__authors__ = [
("Jane Roe <[email protected]>", "@janeroe"),
("John Doe <[email protected]>", "@johndoe"),
]
__maintainers__ = [("Joe Bloggs <[email protected]>", "@joebloggs")]
:param __model__: Mapping of the custom ``ModelInterface`` to a general one to be loaded
by metatensor-models
:param __trainer__: Same as ``__MODEL_CLASS__`` but the Trainer class.
:param __authors__: Tuple denoting the original authors with email address and Github
handle of an architecture. These do not necessary be in charge of maintaining the
the architecture
:param __maintainers__: Tuple denoting the current maintainers of the architecture. Uses
the same style as the ``__authors__`` constant.
7 changes: 7 additions & 0 deletions docs/src/dev-docs/utils/dtype.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Dtype
#####

.. automodule:: metatensor.models.utils.dtype
:members:
:undoc-members:
:show-inheritance:
2 changes: 1 addition & 1 deletion docs/src/dev-docs/utils/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@ This is the API for the ``utils`` module of ``metatensor-models``.
architectures
composition
devices
dtype
errors
evaluate_model
external_naming
export
io
logging
loss
merge_capabilities
metrics
neighbor_lists
omegaconf
Expand Down
7 changes: 0 additions & 7 deletions docs/src/dev-docs/utils/merge_capabilities.rst

This file was deleted.

Binary file added examples/ase/model.pt
Binary file not shown.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ alchemical-model = [
"torch_alchemical @ git+https://github.com/abmazitov/torch_alchemical.git@51ff519",
]
pet = [
"pet @ git+https://github.com/spozdn/pet.git@ad3dc8a",
"pet @ git+https://github.com/spozdn/pet.git@9f6119d",
]

[tool.setuptools.packages.find]
Expand Down
42 changes: 36 additions & 6 deletions src/metatensor/models/__main__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,40 @@
"""The main entry point for the metatensor-models command line interface."""

import argparse
import importlib
import logging
import os
import sys
import traceback
import warnings
from datetime import datetime
from pathlib import Path

import metatensor.torch
from omegaconf import OmegaConf

from . import __version__
from .cli.eval import _add_eval_model_parser, eval_model
from .cli.export import _add_export_model_parser, export_model
from .cli.train import _add_train_model_parser, train_model
from .utils.architectures import check_architecture_name
from .utils.logging import setup_logging


# This import is necessary to avoid errors when loading an
# exported alchemical model, which depends on sphericart-torch.
# TODO: Remove this when https://github.com/lab-cosmo/metatensor/issues/512
# is ready
try:
import sphericart.torch # noqa: F401
except ImportError:
pass

try:
import rascaline.torch # noqa: F401
except ImportError:
pass


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -69,14 +86,27 @@ def main():
args = ap.parse_args()
callable = args.__dict__.pop("callable")
debug = args.__dict__.pop("debug")
logfile = None

if debug:
level = logging.DEBUG
else:
level = logging.INFO
warnings.filterwarnings("ignore") # ignore all warnings if not in debug mode

if callable == "train_model":
if callable == "eval_model":
args.__dict__["model"] = metatensor.torch.atomistic.load_atomistic_model(
path=args.__dict__.pop("path"),
extensions_directory=args.__dict__.pop("extensions_directory"),
)
elif callable == "export_model":
architecture_name = args.__dict__.pop("architecture_name")
check_architecture_name(architecture_name)
architecture = importlib.import_module(f"metatensor.models.{architecture_name}")

args.__dict__["model"] = architecture.__model__.load_checkpoint(
args.__dict__.pop("path")
)
elif callable == "train_model":
# define and create `checkpoint_dir` based on current directory and date/time
checkpoint_dir = _datetime_output_path(now=datetime.now())
os.makedirs(checkpoint_dir)
Expand All @@ -92,7 +122,7 @@ def main():

args.options = OmegaConf.merge(args.options, override_options)
else:
logfile = None
raise ValueError("internal error when selecting a sub-command.")

with setup_logging(logger, logfile=logfile, level=level):
try:
Expand All @@ -104,11 +134,11 @@ def main():
train_model(**args.__dict__)
else:
raise ValueError("internal error when selecting a sub-command.")
except Exception as e:
except Exception as err:
if debug:
traceback.print_exc()
else:
sys.exit(f"\033[31mERROR: {e}\033[0m") # format error in red!
sys.exit(str(err))


if __name__ == "__main__":
Expand Down
33 changes: 20 additions & 13 deletions src/metatensor/models/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
write_predictions,
)
from ..utils.errors import ArchitectureError
from ..utils.evaluate_model import evaluate_model
from ..utils.export import is_exported
from ..utils.io import load
from ..utils.evaluate_model import _get_outputs, evaluate_model
from ..utils.logging import MetricLogger
from ..utils.metrics import RMSEAccumulator
from ..utils.neighbor_lists import get_system_with_neighbor_lists
Expand Down Expand Up @@ -49,15 +47,27 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
)
parser.set_defaults(callable="eval_model")
parser.add_argument(
"model",
type=load,
"path",
type=str,
help="Saved exported model to be evaluated.",
)
parser.add_argument(
"options",
type=OmegaConf.load,
help="Eval options file to define a dataset for evaluation.",
)
parser.add_argument(
"-e",
"--extdir",
type=str,
required=False,
dest="extensions_directory",
default=None,
help=(
"path to a directory containing all extensions required by the exported "
"model"
),
)
parser.add_argument(
"-o",
"--output",
Expand Down Expand Up @@ -186,7 +196,8 @@ def _eval_targets(
rmse_values = rmse_accumulator.finalize(not_per_atom=["positions_gradients"])
# print the RMSEs with MetricLogger
metric_logger = MetricLogger(
model_capabilities=model.capabilities(),
logobj=logger,
model_outputs=_get_outputs(model),
initial_metrics=rmse_values,
)
metric_logger.log(rmse_values)
Expand All @@ -200,7 +211,9 @@ def _eval_targets(


def eval_model(
model: torch.nn.Module, options: DictConfig, output: Union[Path, str] = "output.xyz"
model: Union[MetatensorAtomisticModel, torch.jit._script.RecursiveScriptModule],
options: DictConfig,
output: Union[Path, str] = "output.xyz",
) -> None:
"""Evaluate an exported model on a given data set.
Expand All @@ -212,12 +225,6 @@ def eval_model(
:param options: DictConfig to define a test dataset taken for the evaluation.
:param output: Path to save the predicted values
"""
if not is_exported(model):
raise ValueError(
"The model must already be exported to be used in `eval`. "
"If you are trying to evaluate a checkpoint, export it first "
"with the `metatensor-models export` command."
)
logger.info("Setting up evaluation set.")

# TODO: once https://github.com/lab-cosmo/metatensor/pull/551 is merged and released
Expand Down
Loading

0 comments on commit 19451a6

Please sign in to comment.