Skip to content

Commit

Permalink
Allow spherical tensor targets for nanoPET (#402)
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Dec 17, 2024
1 parent 7ca524b commit 71d7ef5
Show file tree
Hide file tree
Showing 18 changed files with 2,172 additions and 173 deletions.
2 changes: 1 addition & 1 deletion docs/src/advanced-concepts/fitting-generic-targets.rst
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ capabilities of the architectures in metatrain.
* - NanoPET
- Energy, forces, stress/virial
- Yes
- No
- Yes
- Only with ``rank=1`` (vectors)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,4 @@ def test_regression_train():
# torch.set_printoptions(precision=12)
# print(output["mtt::U0"].block().values)

torch.testing.assert_close(
output["mtt::U0"].block().values, expected_output, rtol=0, atol=0
)
torch.testing.assert_close(output["mtt::U0"].block().values, expected_output)
79 changes: 51 additions & 28 deletions src/metatrain/experimental/nanopet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,17 +50,11 @@ class NanoPET(torch.nn.Module):
__supported_devices__ = ["cuda", "cpu"]
__supported_dtypes__ = [torch.float64, torch.float32]

component_labels: Dict[str, List[Labels]]
component_labels: Dict[str, List[List[Labels]]]

def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
super().__init__()

for target in dataset_info.targets.values():
if target.is_spherical:
raise ValueError(
"The NanoPET model does not support spherical tensor targets. "
"Only scalar and Cartesian tensor targets are supported."
)
# checks on targets inside the RotationalAugmenter class in the trainer

self.hypers = model_hypers
self.dataset_info = dataset_info
Expand Down Expand Up @@ -125,10 +119,10 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
self.heads = torch.nn.ModuleDict()
self.head_types = self.hypers["heads"]
self.last_layers = torch.nn.ModuleDict()
self.output_shapes: Dict[str, List[int]] = {}
self.output_shapes: Dict[str, List[List[int]]] = {}
self.key_labels: Dict[str, Labels] = {}
self.component_labels: Dict[str, List[Labels]] = {}
self.property_labels: Dict[str, Labels] = {}
self.component_labels: Dict[str, List[List[Labels]]] = {}
self.property_labels: Dict[str, List[Labels]] = {}
for target_name, target_info in dataset_info.targets.items():
self._add_output(target_name, target_info)

Expand Down Expand Up @@ -215,12 +209,15 @@ def forward(
for output_name, label in self.key_labels.items()
}
self.component_labels = {
output_name: [label.to(device) for label in labels]
for output_name, labels in self.component_labels.items()
output_name: [
[labels.to(device) for labels in components_block]
for components_block in components_tmap
]
for output_name, components_tmap in self.component_labels.items()
}
self.property_labels = {
output_name: label.to(device)
for output_name, label in self.property_labels.items()
output_name: [labels.to(device) for labels in properties_tmap]
for output_name, properties_tmap in self.property_labels.items()
}

system_indices = torch.concatenate(
Expand Down Expand Up @@ -430,17 +427,28 @@ def forward(
if output_name in outputs:
atomic_features = atomic_features_dict[output_name]
atomic_properties = last_layer(atomic_features)
block = TensorBlock(
values=atomic_properties.reshape(
[-1] + self.output_shapes[output_name]
),
samples=sample_labels,
components=self.component_labels[output_name],
properties=self.property_labels[output_name],
split_atomic_properties_by_block = torch.split(
atomic_properties,
[manual_prod(shape) for shape in self.output_shapes[output_name]],
dim=-1,
)
blocks = [
TensorBlock(
values=atomic_property.reshape([-1] + shape),
samples=sample_labels,
components=components,
properties=properties,
)
for atomic_property, shape, components, properties in zip(
split_atomic_properties_by_block,
self.output_shapes[output_name],
self.component_labels[output_name],
self.property_labels[output_name],
)
]
atomic_properties_tmap_dict[output_name] = TensorMap(
keys=self.key_labels[output_name],
blocks=[block],
blocks=blocks,
)

if selected_atoms is not None:
Expand Down Expand Up @@ -529,9 +537,12 @@ def export(self) -> MetatensorAtomisticModel:

def _add_output(self, target_name: str, target_info: TargetInfo) -> None:

# one output shape for each tensor block
self.output_shapes[target_name] = [
len(comp.values) for comp in target_info.layout.block().components
] + [len(target_info.layout.block().properties.values)]
[len(comp.values) for comp in block.components]
+ [len(block.properties.values)]
for block in target_info.layout.blocks()
]
self.outputs[target_name] = ModelOutput(
quantity=target_info.quantity,
unit=target_info.unit,
Expand All @@ -557,10 +568,22 @@ def _add_output(self, target_name: str, target_info: TargetInfo) -> None:

self.last_layers[target_name] = torch.nn.Linear(
self.hypers["d_pet"],
prod(self.output_shapes[target_name]),
sum(prod(shape) for shape in self.output_shapes[target_name]),
bias=False,
)

self.key_labels[target_name] = target_info.layout.keys
self.component_labels[target_name] = target_info.layout.block().components
self.property_labels[target_name] = target_info.layout.block().properties
self.component_labels[target_name] = [
block.components for block in target_info.layout.blocks()
]
self.property_labels[target_name] = [
block.properties for block in target_info.layout.blocks()
]


def manual_prod(shape: List[int]) -> int:
# prod from standard library not supported in torchscript
result = 1
for dim in shape:
result *= dim
return result
Loading

0 comments on commit 71d7ef5

Please sign in to comment.