Skip to content

Commit

Permalink
Make the SOAP-BPNN torch-scriptable
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Dec 1, 2023
1 parent d92ad11 commit 9c52372
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 11 deletions.
26 changes: 15 additions & 11 deletions src/metatensor_models/soap_bpnn/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,23 +53,27 @@ def __init__(self, all_species: List[int], hypers: dict) -> None:

def forward(self, features: TensorMap) -> TensorMap:
# Create a list of the blocks that are present in the features:
present_blocks = [int(key.values.item()) for key in features.keys]
present_blocks = [
int(features.keys.entry(i).values.item())
for i in range(features.keys.values.shape[0])
]

new_blocks: List[TensorBlock] = []
for species_str, network in self.layers.items():
species = int(species_str)
if species not in present_blocks:
continue
block = features.block({"species_center": species})
output_values = network(block.values)
new_blocks.append(
TensorBlock(
values=output_values,
samples=block.samples,
components=block.components,
properties=Labels.range("properties", output_values.shape[-1]),
pass # continue is not accepted by torchscript here
else:
block = features.block({"species_center": species})
output_values = network(block.values)
new_blocks.append(
TensorBlock(
values=output_values,
samples=block.samples,
components=block.components,
properties=Labels.range("properties", output_values.shape[-1]),
)
)
)
return TensorMap(keys=features.keys, blocks=new_blocks)


Expand Down
20 changes: 20 additions & 0 deletions src/metatensor_models/soap_bpnn/tests/test_torchscript.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
import os

import torch
import yaml

from metatensor_models.soap_bpnn import SoapBPNN


path = os.path.dirname(__file__)
hypers_path = os.path.join(path, "../default.yml")
dataset_path = os.path.join(path, "data/qm9_reduced_100.xyz")


def test_torchscript():
"""Tests that the model can be jitted."""

all_species = [1, 6, 7, 8, 9]
hypers = yaml.safe_load(open(hypers_path, "r"))
soap_bpnn = SoapBPNN(all_species, hypers).to(torch.float64)
torch.jit.script(soap_bpnn)

0 comments on commit 9c52372

Please sign in to comment.