Skip to content

Commit

Permalink
Remove consistency checks (except for tests) (#293)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Jul 18, 2024
1 parent e858dcd commit e2fb721
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 16 deletions.
2 changes: 1 addition & 1 deletion examples/programmatic/llpr/llpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@
selected_atoms=None,
)

outputs = exported_model([ethanol_system], evaluation_options, check_consistency=True)
outputs = exported_model([ethanol_system], evaluation_options, check_consistency=False)
lpr = outputs["mtt::aux::energy_uncertainty"].block().values.detach().cpu().numpy()

# %%
Expand Down
20 changes: 18 additions & 2 deletions src/metatrain/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,12 @@ def _add_eval_model_parser(subparser: argparse._SubParsersAction) -> None:
default="output.xyz",
help="filename of the predictions (default: %(default)s)",
)
parser.add_argument(
"--check-consistency",
dest="check_consistency",
action="store_true",
help="whether to run consistency checks (default: %(default)s)",
)


def _prepare_eval_model_args(args: argparse.Namespace) -> None:
Expand Down Expand Up @@ -150,6 +156,7 @@ def _eval_targets(
dataset: Union[Dataset, torch.utils.data.Subset],
options: TargetInfoDict,
return_predictions: bool,
check_consistency: bool = False,
) -> Optional[Dict[str, TensorMap]]:
"""Evaluates an exported model on a dataset and prints the RMSEs for each target.
Optionally, it also returns the predictions of the model.
Expand Down Expand Up @@ -195,7 +202,13 @@ def _eval_targets(
key: value.to(dtype=dtype, device=device)
for key, value in batch_targets.items()
}
batch_predictions = evaluate_model(model, systems, options, is_training=False)
batch_predictions = evaluate_model(
model,
systems,
options,
is_training=False,
check_consistency=check_consistency,
)
batch_predictions = average_by_num_atoms(
batch_predictions, systems, per_structure_keys=[]
)
Expand Down Expand Up @@ -228,6 +241,7 @@ def eval_model(
model: Union[MetatensorAtomisticModel, torch.jit._script.RecursiveScriptModule],
options: DictConfig,
output: Union[Path, str] = "output.xyz",
check_consistency: bool = False,
) -> None:
"""Evaluate an exported model on a given data set.
Expand All @@ -237,7 +251,8 @@ def eval_model(
:param model: Saved model to be evaluated.
:param options: DictConfig to define a test dataset taken for the evaluation.
:param output: Path to save the predicted values
:param output: Path to save the predicted values.
:param check_consistency: Whether to run consistency checks during model evaluation.
"""
logger.info("Setting up evaluation set.")

Expand Down Expand Up @@ -290,6 +305,7 @@ def eval_model(
dataset=eval_dataset,
options=eval_info_dict,
return_predictions=True,
check_consistency=check_consistency,
)
except Exception as e:
raise ArchitectureError(e)
Expand Down
14 changes: 9 additions & 5 deletions src/metatrain/utils/evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def evaluate_model(
systems: List[System],
targets: TargetInfoDict,
is_training: bool,
check_consistency: bool = False,
) -> Dict[str, TensorMap]:
"""
Evaluate the model (in training or exported) on a set of requested targets.
Expand Down Expand Up @@ -75,13 +76,14 @@ def evaluate_model(
system,
positions_grad=len(energy_targets_that_require_position_gradients) > 0,
strain_grad=len(energy_targets_that_require_strain_gradients) > 0,
check_consistency=check_consistency,
)
new_systems.append(new_system)
strains.append(strain)
systems = new_systems

# Based on the keys of the targets, get the outputs of the model:
model_outputs = _get_model_outputs(model, systems, targets)
model_outputs = _get_model_outputs(model, systems, targets, check_consistency)

for energy_target in energy_targets:
# If the energy target requires gradients, compute them:
Expand Down Expand Up @@ -233,6 +235,7 @@ def _get_model_outputs(
],
systems: List[System],
targets: TargetInfoDict,
check_consistency: bool,
) -> Dict[str, TensorMap]:
if is_exported(model):
# put together an EvaluationOptions object
Expand All @@ -245,8 +248,7 @@ def _get_model_outputs(
for key, value in targets.items()
},
)
# we check consistency here because this could be called from eval
return model(systems, options, check_consistency=True)
return model(systems, options, check_consistency=check_consistency)
else:
return model(
systems,
Expand All @@ -259,7 +261,9 @@ def _get_model_outputs(
)


def _prepare_system(system: System, positions_grad: bool, strain_grad: bool):
def _prepare_system(
system: System, positions_grad: bool, strain_grad: bool, check_consistency: bool
):
"""
Prepares a system for gradient calculation.
"""
Expand Down Expand Up @@ -294,7 +298,7 @@ def _prepare_system(system: System, positions_grad: bool, strain_grad: bool):
for nl_options in system.known_neighbor_lists():
nl = system.get_neighbor_list(nl_options)
nl = metatensor.torch.detach_block(nl)
register_autograd_neighbors(new_system, nl, check_consistency=True)
register_autograd_neighbors(new_system, nl, check_consistency)
new_system.add_neighbor_list(nl_options, nl)

return new_system, strain
10 changes: 3 additions & 7 deletions src/metatrain/utils/llpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def forward(
outputs=outputs,
selected_atoms=selected_atoms,
)
return self.model(systems, options, check_consistency=True)
return self.model(systems, options, check_consistency=False)

per_atom_all_targets = [output.per_atom for output in outputs.values()]
# impose either all per atom or all not per atom
Expand Down Expand Up @@ -130,9 +130,7 @@ def forward(
outputs=outputs_for_model,
selected_atoms=selected_atoms,
)
return_dict = self.model(
systems, options, check_consistency=True
) # TODO: True or False here?
return_dict = self.model(systems, options, check_consistency=False)

ll_features = return_dict["mtt::aux::last_layer_features"]

Expand Down Expand Up @@ -248,9 +246,7 @@ class in ``metatrain``.
length_unit="",
outputs=outputs,
)
output = self.model(
systems, options, check_consistency=True
) # TODO: True or False here?
output = self.model(systems, options, check_consistency=False)
ll_feat_tmap = output["mtt::aux::last_layer_features"]
ll_feats = ll_feat_tmap.block().values / n_atoms.unsqueeze(1)
self.covariance += ll_feats.T @ ll_feats
Expand Down
5 changes: 5 additions & 0 deletions tests/cli/test_eval_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def test_eval_cli(monkeypatch, tmp_path):
str(EVAL_OPTIONS_PATH),
"-e",
str(RESOURCES_PATH / "extensions"),
"--check-consistency",
]

output = subprocess.check_output(command, stderr=subprocess.STDOUT)
Expand All @@ -60,6 +61,7 @@ def test_eval(monkeypatch, tmp_path, caplog, model_name, options):
model=model,
options=options,
output="foo.xyz",
check_consistency=True,
)

# Test target predictions
Expand Down Expand Up @@ -94,6 +96,7 @@ def test_eval_export(monkeypatch, tmp_path, options):
model=exported_model,
options=options,
output="foo.xyz",
check_consistency=True,
)


Expand All @@ -108,6 +111,7 @@ def test_eval_multi_dataset(monkeypatch, tmp_path, caplog, model, options):
model=model,
options=OmegaConf.create([options, options]),
output="foo.xyz",
check_consistency=True,
)

# Test target predictions
Expand All @@ -131,6 +135,7 @@ def test_eval_no_targets(monkeypatch, tmp_path, model, options):
eval_model(
model=model,
options=options,
check_consistency=True,
)

assert Path("output.xyz").is_file()
4 changes: 3 additions & 1 deletion tests/utils/test_evaluate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def test_evaluate_model(training, exported):
]

systems = [system.to(torch.float32) for system in systems]
outputs = evaluate_model(model, systems, targets, is_training=training)
outputs = evaluate_model(
model, systems, targets, is_training=training, check_consistency=True
)

assert isinstance(outputs, dict)
assert "energy" in outputs
Expand Down

0 comments on commit e2fb721

Please sign in to comment.