diff --git a/src/metatensor_models/soap_bpnn/model.py b/src/metatensor_models/soap_bpnn/model.py index 4eab08d90..999c17dbe 100644 --- a/src/metatensor_models/soap_bpnn/model.py +++ b/src/metatensor_models/soap_bpnn/model.py @@ -53,23 +53,27 @@ def __init__(self, all_species: List[int], hypers: dict) -> None: def forward(self, features: TensorMap) -> TensorMap: # Create a list of the blocks that are present in the features: - present_blocks = [int(key.values.item()) for key in features.keys] + present_blocks = [ + int(features.keys.entry(i).values.item()) + for i in range(features.keys.values.shape[0]) + ] new_blocks: List[TensorBlock] = [] for species_str, network in self.layers.items(): species = int(species_str) if species not in present_blocks: - continue - block = features.block({"species_center": species}) - output_values = network(block.values) - new_blocks.append( - TensorBlock( - values=output_values, - samples=block.samples, - components=block.components, - properties=Labels.range("properties", output_values.shape[-1]), + pass # continue is not accepted by torchscript here + else: + block = features.block({"species_center": species}) + output_values = network(block.values) + new_blocks.append( + TensorBlock( + values=output_values, + samples=block.samples, + components=block.components, + properties=Labels.range("properties", output_values.shape[-1]), + ) ) - ) return TensorMap(keys=features.keys, blocks=new_blocks) diff --git a/src/metatensor_models/soap_bpnn/tests/test_torchscript.py b/src/metatensor_models/soap_bpnn/tests/test_torchscript.py new file mode 100644 index 000000000..6e4529906 --- /dev/null +++ b/src/metatensor_models/soap_bpnn/tests/test_torchscript.py @@ -0,0 +1,20 @@ +import os + +import torch +import yaml + +from metatensor_models.soap_bpnn import SoapBPNN + + +path = os.path.dirname(__file__) +hypers_path = os.path.join(path, "../default.yml") +dataset_path = os.path.join(path, "data/qm9_reduced_100.xyz") + + +def test_torchscript(): + """Tests that the model can be jitted.""" + + all_species = [1, 6, 7, 8, 9] + hypers = yaml.safe_load(open(hypers_path, "r")) + soap_bpnn = SoapBPNN(all_species, hypers).to(torch.float64) + torch.jit.script(soap_bpnn)