diff --git a/src/metatensor_models/soap_bpnn/model.py b/src/metatensor_models/soap_bpnn/model.py index 999c17dbe..7609a84b5 100644 --- a/src/metatensor_models/soap_bpnn/model.py +++ b/src/metatensor_models/soap_bpnn/model.py @@ -74,6 +74,7 @@ def forward(self, features: TensorMap) -> TensorMap: properties=Labels.range("properties", output_values.shape[-1]), ) ) + return TensorMap(keys=features.keys, blocks=new_blocks) diff --git a/src/metatensor_models/soap_bpnn/tests/test_functionality.py b/src/metatensor_models/soap_bpnn/tests/test_functionality.py new file mode 100644 index 000000000..8e88eb8bf --- /dev/null +++ b/src/metatensor_models/soap_bpnn/tests/test_functionality.py @@ -0,0 +1,25 @@ +import os + +import ase +import rascaline.torch +import torch +import yaml + +from metatensor_models.soap_bpnn import SoapBPNN + + +path = os.path.dirname(__file__) +hypers_path = os.path.join(path, "../default.yml") +dataset_path = os.path.join(path, "data/qm9_reduced_100.xyz") + + +def test_prediction_subset(): + """Tests that the model can predict on a subset + of the elements it was trained on.""" + + all_species = [1, 6, 7, 8, 9] + hypers = yaml.safe_load(open(hypers_path, "r")) + soap_bpnn = SoapBPNN(all_species, hypers).to(torch.float64) + + structure = ase.Atoms("O2", positions=[[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) + soap_bpnn([rascaline.torch.systems_to_torch(structure)])