Skip to content

Commit

Permalink
Add virial dataset, replacing the alchemical dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Jul 23, 2024
1 parent e14207b commit 3c30e09
Show file tree
Hide file tree
Showing 10 changed files with 636 additions and 1,802 deletions.
4 changes: 2 additions & 2 deletions src/metatrain/experimental/alchemical_model/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@

DATASET_PATH = str(Path(__file__).parents[5] / "tests/resources/qm9_reduced_100.xyz")

ALCHEMICAL_DATASET_PATH = str(
Path(__file__).parents[5] / "tests/resources/alchemical_reduced_10.xyz"
CARBON_DATASET_PATH = str(
Path(__file__).parents[5] / "tests/resources/carbon_reduced_100.xyz"
)

DEFAULT_HYPERS = get_default_hypers("experimental.alchemical_model")
Expand Down
4 changes: 1 addition & 3 deletions src/metatrain/experimental/pet/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from pathlib import Path

DATASET_PATH = str(
Path(__file__).parents[5] / "tests/resources/alchemical_reduced_10.xyz"
)
DATASET_PATH = str(Path(__file__).parents[5] / "tests/resources/carbon_reduced_100.xyz")
1 change: 1 addition & 0 deletions tests/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

DATASET_PATH_QM9 = RESOURCES_PATH / "qm9_reduced_100.xyz"
DATASET_PATH_ETHANOL = RESOURCES_PATH / "ethanol_reduced_100.xyz"
DATASET_PATH_CARBON = RESOURCES_PATH / "carbon_reduced_100.xyz"
EVAL_OPTIONS_PATH = RESOURCES_PATH / "eval.yaml"
MODEL_PATH = RESOURCES_PATH / "model-32-bit.pt"
MODEL_PATH_64_BIT = RESOURCES_PATH / "model-64-bit.ckpt"
Expand Down
19 changes: 19 additions & 0 deletions tests/cli/test_train_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from metatrain import RANDOM_SEED
from metatrain.cli.train import train_model
from metatrain.utils.errors import ArchitectureError
from metatrain.utils.logging import setup_logging

from . import (
DATASET_PATH_ETHANOL,
Expand Down Expand Up @@ -510,3 +511,21 @@ def test_train_issue_290(monkeypatch, tmp_path):
options["test_set"] = 0.85

train_model(options)


# def test_train_log_order(caplog, monkeypatch, tmp_path, options):
# """Tests that the log is always printed in the same order for forces
# and virials."""

# caplog.set_level(logging.INFO)
# logger = logging.getLogger()

# with setup_logging(logger, level=logging.INFO):
# logger.info("foo")
# logger.debug("A debug message")

# stdout_log = capsys.readouterr().out

# assert "Logging to file is disabled." not in caplog.text # DEBUG message
# assert_log_entry(stdout_log, loglevel="INFO", message="foo")
# assert "A debug message" not in stdout_log
428 changes: 0 additions & 428 deletions tests/resources/alchemical_reduced_10.xyz

This file was deleted.

600 changes: 600 additions & 0 deletions tests/resources/carbon_reduced_100.xyz

Large diffs are not rendered by default.

1,320 changes: 0 additions & 1,320 deletions tests/resources/carbon_reduced_20.xyz

This file was deleted.

14 changes: 8 additions & 6 deletions tests/utils/data/test_combine_dataloaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,22 @@ def test_without_shuffling():
dataloader_qm9 = DataLoader(dataset, batch_size=10, collate_fn=collate_fn)
# will yield 10 batches of 10

systems = read_systems(RESOURCES_PATH / "alchemical_reduced_10.xyz")
systems = read_systems(RESOURCES_PATH / "carbon_reduced_100.xyz")[:10]

conf = {
"mtt::free_energy": {
"quantity": "energy",
"read_from": RESOURCES_PATH / "alchemical_reduced_10.xyz",
"read_from": RESOURCES_PATH / "carbon_reduced_100.xyz",
"reader": "ase",
"key": "free_energy",
"key": "energy",
"unit": "eV",
"forces": False,
"stress": False,
"virial": False,
}
}
targets, _ = read_targets(OmegaConf.create(conf))
targets = {"mtt::free_energy": targets["mtt::free_energy"][:10]}
dataset = Dataset(
{"system": systems, "mtt::free_energy": targets["mtt::free_energy"]}
)
Expand Down Expand Up @@ -99,21 +100,22 @@ def test_with_shuffling():
)
# will yield 10 batches of 10

systems = read_systems(RESOURCES_PATH / "alchemical_reduced_10.xyz")
systems = read_systems(RESOURCES_PATH / "carbon_reduced_100.xyz")[:10]

conf = {
"mtt::free_energy": {
"quantity": "energy",
"read_from": RESOURCES_PATH / "alchemical_reduced_10.xyz",
"read_from": RESOURCES_PATH / "carbon_reduced_100.xyz",
"reader": "ase",
"key": "free_energy",
"key": "energy",
"unit": "eV",
"forces": False,
"stress": False,
"virial": False,
}
}
targets, _ = read_targets(OmegaConf.create(conf))
targets = {"mtt::free_energy": targets["mtt::free_energy"][:10]}
dataset = Dataset(
{"system": systems, "mtt::free_energy": targets["mtt::free_energy"]}
)
Expand Down
2 changes: 1 addition & 1 deletion tests/utils/test_evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
def test_evaluate_model(training, exported):
"""Test that the evaluate_model function works as intended."""

systems = read_systems(RESOURCES_PATH / "alchemical_reduced_10.xyz")[:2]
systems = read_systems(RESOURCES_PATH / "carbon_reduced_100.xyz")[:2]

atomic_types = set(
torch.unique(torch.concatenate([system.types for system in systems]))
Expand Down
46 changes: 4 additions & 42 deletions tests/utils/test_output_gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,26 +73,7 @@ def test_virial(is_training):

dataset_info = DatasetInfo(
length_unit="angstrom",
atomic_types={
21,
23,
24,
26,
27,
29,
30,
39,
40,
41,
44,
45,
46,
47,
72,
74,
77,
78,
},
atomic_types={6},
targets={
"energy": TargetInfo(
quantity="energy", unit="eV", per_atom=False, gradients=["strain"]
Expand All @@ -102,7 +83,7 @@ def test_virial(is_training):
model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info)
model.to(dtype=torch.float64)

systems = read_systems(RESOURCES_PATH / "alchemical_reduced_10.xyz")[:2]
systems = read_systems(RESOURCES_PATH / "carbon_reduced_100.xyz")[:2]

strains = [
torch.eye(
Expand Down Expand Up @@ -161,26 +142,7 @@ def test_both(is_training):
"""Test that the forces and virial are calculated correctly together"""
dataset_info = DatasetInfo(
length_unit="angstrom",
atomic_types={
21,
23,
24,
26,
27,
29,
30,
39,
40,
41,
44,
45,
46,
47,
72,
74,
77,
78,
},
atomic_types={6},
targets={
"energy": TargetInfo(
quantity="energy",
Expand All @@ -193,7 +155,7 @@ def test_both(is_training):
model = __model__(model_hypers=MODEL_HYPERS, dataset_info=dataset_info)
model.to(dtype=torch.float64)

systems = read_systems(RESOURCES_PATH / "alchemical_reduced_10.xyz")[:2]
systems = read_systems(RESOURCES_PATH / "carbon_reduced_100.xyz")[:2]

# Here we re-create strains and systems, otherwise torch
# complains that the graph has already beeen freed in the last grad call
Expand Down

0 comments on commit 3c30e09

Please sign in to comment.