Pass last-layer features from PET to metatrain-PET #407
@@ -61,6 +61,12 @@ def __init__(self, model_hypers: Dict, dataset_info: DatasetInfo) -> None:
        self.is_lora_applied = False
        self.checkpoint_path: Optional[str] = None

        # last-layer feature size (for LLPR module)
        self.last_layer_feature_size = (
            self.hypers["N_GNN_LAYERS"] * self.hypers["HEAD_N_NEURONS"] * 2
Is this correct? In the pure PET code, the code for calculating LLFs looks like this:
    last_layer_features = []
    if "central_token" in result.keys():
        predictor_output = central_tokens_predictor(
            result["central_token"], central_species
        )
        last_layer_features.append(predictor_output["features"])
    if self.USE_BOND_ENERGIES:
        predictor_output = messages_bonds_predictor(
            output_messages, mask, nums, central_species, multipliers
        )
        last_layer_features.append(predictor_output["features"])
    last_layer_features = torch.concatenate(last_layer_features, dim=1)
    last_layer_features = torch_geometric.nn.global_mean_pool(
        last_layer_features, batch=batch_dict["batch"]
    )
So the size is not necessarily multiplied by 2; that only happens if self.USE_BOND_ENERGIES is True.
Thanks a lot, you're right
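The point settled in this thread can be sketched as follows. This is a minimal illustration, not the code that was merged: the helper name `last_layer_feature_size` is hypothetical, reading `USE_BOND_ENERGIES` from the hypers dictionary is an assumption (the snippet above only shows it as a PET attribute), and the central-token head is assumed to always contribute features. The base width `N_GNN_LAYERS * HEAD_N_NEURONS` is taken from the diff, and the hyper values in the usage lines are made up for the example.

```python
from typing import Dict


def last_layer_feature_size(hypers: Dict) -> int:
    """Width of the concatenated last-layer features (hypothetical helper)."""
    # Base width of one feature block, as written in the diff under review.
    base = hypers["N_GNN_LAYERS"] * hypers["HEAD_N_NEURONS"]
    # One block from the central-token head (assumed to be always present),
    # plus one from the bond head only when bond energies are enabled.
    # The "USE_BOND_ENERGIES" hyper key is an assumption.
    n_blocks = 2 if hypers.get("USE_BOND_ENERGIES", False) else 1
    return base * n_blocks


# Illustrative hyperparameter values, not taken from the PR.
print(last_layer_feature_size(
    {"N_GNN_LAYERS": 2, "HEAD_N_NEURONS": 128, "USE_BOND_ENERGIES": True}
))  # 512
print(last_layer_feature_size(
    {"N_GNN_LAYERS": 2, "HEAD_N_NEURONS": 128, "USE_BOND_ENERGIES": False}
))  # 256
```

Because the two feature blocks are concatenated along `dim=1`, the factor counts how many blocks are appended, so a hard-coded `* 2` overstates the size whenever the bond head is disabled.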
        if "mtt::aux::last_layer_features" in outputs:
            ll_features = output["last_layer_features"]
            print(ll_features.shape)
This print is a debugging leftover, I suppose.
Whooops thanks!
        else:
            pet_model = PET(ARCHITECTURAL_HYPERS, 0.0, len(all_species))
            pet_model = pet_model.to(device=device, dtype=dtype)
This should be already merged, right?
This should be gone now that I've updated the branch
Branch updated from 9f37889 to f243961
Looks good to me. I think we can merge, but let's first merge another PET hotfix, #408.
PET can output last-layer features. This PR passes them through the metatrain interface so that, e.g., exported models can be used in chemiscope (after #398) or with the LLPR module.
📚 Documentation preview 📚: https://metatrain--407.org.readthedocs.build/en/407/