Pass last-layer features from PET to metatrain-PET #407
@@ -61,6 +61,12 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
         self.is_lora_applied = False
         self.checkpoint_path: Optional[str] = None
 
+        # last-layer feature size (for LLPR module)
+        self.last_layer_feature_size = (
+            self.hypers["N_GNN_LAYERS"] * self.hypers["HEAD_N_NEURONS"] * 2
+        )
+        # times 2 because of the concatenation of the node and edge features
+
         # additive models: these are handled by the trainer at training
         # time, and they are added to the output at evaluation time
         additive_models = []
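For concreteness, here is the arithmetic behind the new attribute as a standalone sketch. The hyperparameter values are hypothetical, not taken from any real PET configuration, and the review thread at the bottom of this page points out that the factor of 2 should actually be conditional:

```python
# Standalone sketch of the size computation above; hypothetical values.
hypers = {"N_GNN_LAYERS": 3, "HEAD_N_NEURONS": 128}

# one head output per GNN layer, doubled by the node/edge concatenation
last_layer_feature_size = hypers["N_GNN_LAYERS"] * hypers["HEAD_N_NEURONS"] * 2
assert last_layer_feature_size == 768
```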
@@ -123,20 +129,44 @@ def forward(
         output = self.pet(batch)  # type: ignore
         predictions = output["prediction"]
         output_quantities: Dict[str, TensorMap] = {}
 
+        structure_index = batch["batch"]
+        _, counts = torch.unique(batch["batch"], return_counts=True)
+        atom_index = torch.cat(
+            [torch.arange(count, device=predictions.device) for count in counts]
+        )
+        samples_values = torch.stack([structure_index, atom_index], dim=1)
+        samples = Labels(names=["system", "atom"], values=samples_values)
+        empty_labels = Labels(
+            names=["_"], values=torch.tensor([[0]], device=predictions.device)
+        )
+
+        if "mtt::aux::last_layer_features" in outputs:
+            ll_features = output["last_layer_features"]
+            print(ll_features.shape)
[Reviewer] This `print` is a debugging leftover, I suppose.
[Author] Whoops, thanks!
+            block = TensorBlock(
+                values=ll_features,
+                samples=samples,
+                components=[],
+                properties=Labels(
+                    names=["properties"],
+                    values=torch.arange(ll_features.shape[1]).reshape(-1, 1),
+                ),
+            )
+            output_tmap = TensorMap(
+                keys=empty_labels,
+                blocks=[block],
+            )
+            if not outputs["mtt::aux::last_layer_features"].per_atom:
+                output_tmap = metatensor.torch.sum_over_samples(output_tmap, "atom")
+            output_quantities["mtt::aux::last_layer_features"] = output_tmap
+
         for output_name in outputs:
+            if output_name.startswith("mtt::aux::"):
+                continue  # skip auxiliary outputs (not targets)
             energy_labels = Labels(
                 names=["energy"], values=torch.tensor([[0]], device=predictions.device)
             )
-            empty_labels = Labels(
-                names=["_"], values=torch.tensor([[0]], device=predictions.device)
-            )
-            structure_index = batch["batch"]
-            _, counts = torch.unique(batch["batch"], return_counts=True)
-            atom_index = torch.cat(
-                [torch.arange(count, device=predictions.device) for count in counts]
-            )
-            samples_values = torch.stack([structure_index, atom_index], dim=1)
-            samples = Labels(names=["system", "atom"], values=samples_values)
             block = TensorBlock(
                 samples=samples,
                 components=[],
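To make the sample-building logic above concrete, here is a self-contained sketch of the same steps. It assumes `torch` and `metatensor-torch` are installed; the `batch` vector and the random `ll_features` tensor are made up for illustration:

```python
import torch
import metatensor.torch
from metatensor.torch import Labels, TensorBlock, TensorMap

# torch_geometric-style batch vector: entry i is the structure index of atom i
batch = torch.tensor([0, 0, 0, 1, 1])  # 3 atoms in structure 0, 2 in structure 1

structure_index = batch
_, counts = torch.unique(batch, return_counts=True)
atom_index = torch.cat([torch.arange(count) for count in counts])
samples_values = torch.stack([structure_index, atom_index], dim=1)
# samples_values == [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]]
samples = Labels(names=["system", "atom"], values=samples_values)

# wrap fake per-atom features the same way the diff wraps ll_features
ll_features = torch.randn(5, 4)  # (n_atoms, n_features)
block = TensorBlock(
    values=ll_features,
    samples=samples,
    components=[],
    properties=Labels(
        names=["properties"],
        values=torch.arange(ll_features.shape[1]).reshape(-1, 1),
    ),
)
tmap = TensorMap(
    keys=Labels(names=["_"], values=torch.tensor([[0]])),
    blocks=[block],
)

# summing over the "atom" sample dimension gives per-structure features,
# matching the per_atom=False branch in the diff
per_structure = metatensor.torch.sum_over_samples(tmap, "atom")
assert per_structure.block().values.shape == (2, 4)
```

Note the design choice reflected in the diff: the `samples` and `empty_labels` objects are built once before the output loop instead of once per requested output, which is why the duplicated in-loop construction is removed.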
@@ -216,7 +246,10 @@ def export(self) -> MetatensorAtomisticModel:
                     quantity=self.dataset_info.targets[self.target_name].quantity,
                     unit=self.dataset_info.targets[self.target_name].unit,
                     per_atom=False,
-                )
+                ),
+                "mtt::aux::last_layer_features": ModelOutput(
+                    unit="unitless", per_atom=True
+                ),
             },
             atomic_types=self.atomic_types,
             interaction_range=interaction_range,
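Once the capability is declared in `export()`, a caller can request the feature map alongside the energy. A hypothetical sketch of such a request (the energy options and the exact way the dict is passed to the model are illustrative, not taken from this PR):

```python
from metatensor.torch.atomistic import ModelOutput

# outputs dict as it would be handed to the model's forward();
# per_atom=False exercises the sum_over_samples branch above
requested = {
    "energy": ModelOutput(quantity="energy", unit="eV", per_atom=False),
    "mtt::aux::last_layer_features": ModelOutput(unit="unitless", per_atom=False),
}
```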
@@ -256,9 +256,9 @@ def train(
         else:
             pet_model.hypers.TARGET_TYPE = "structural"
             pet_model.TARGET_TYPE = "structural"
+        pet_model = pet_model.to(device=device, dtype=dtype)
     else:
         pet_model = PET(ARCHITECTURAL_HYPERS, 0.0, len(all_species))
+        pet_model = pet_model.to(device=device, dtype=dtype)
[Reviewer] This should already be merged, right?
[Author] This should be gone now that I've updated the branch.
 
         num_params = sum([p.numel() for p in pet_model.parameters()])
         logging.info(f"Number of parameters: {num_params}")
[Reviewer] Is this correct? In the pure PET code, the calculation of the last-layer features shows that the size is not necessarily multiplied by 2: the node and edge features are only concatenated when `self.USE_BOND_ENERGIES` is `True`.
[Author] Thanks a lot, you're right.
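A sketch of the fix implied by this thread, under the assumption that `USE_BOND_ENERGIES` lives in the same hypers mapping as `N_GNN_LAYERS` and `HEAD_N_NEURONS` (the values below are hypothetical):

```python
# Hypothetical hyperparameter values, for illustration only.
hypers = {"N_GNN_LAYERS": 3, "HEAD_N_NEURONS": 128, "USE_BOND_ENERGIES": False}

last_layer_feature_size = hypers["N_GNN_LAYERS"] * hypers["HEAD_N_NEURONS"]
if hypers["USE_BOND_ENERGIES"]:
    # node and edge features are concatenated, doubling the width
    last_layer_feature_size *= 2

assert last_layer_feature_size == 384  # no edge features in this example
```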