Utility update for generic targets #398
Conversation
I am quite concerned by the approach of applying building blocks (zbl/composition/llpr) to all outputs of a model and then automagically removing some outputs inside the building block. The building blocks would be more usable if model developers had to explicitly specify which outputs the building blocks should be applied to.
I'll also let @abmazitov review the changes to PET, and @DavideTisi review the changes to GAP, as maintainers of these architectures!
src/metatrain/utils/additive/zbl.py
Outdated
```diff
@@ -168,7 +170,7 @@ def forward(
         # Set the outputs as the ZBL energies
         targets_out: Dict[str, TensorMap] = {}
         for target_key, target in outputs.items():
-            if target_key.startswith("mtt::aux::"):
+            if is_auxiliary_output(target_key):
```
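For context, a minimal sketch of what the `is_auxiliary_output` helper presumably does, assuming it simply wraps the prefix check it replaces in the diff above (the body here is a guess, not the actual metatrain implementation):

```python
def is_auxiliary_output(target_key: str) -> bool:
    # Assumed implementation: auxiliary outputs are namespaced under the
    # "mtt::aux::" prefix, matching the string check this helper replaces.
    return target_key.startswith("mtt::aux::")
```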
I'm not sure this will be sustainable long term. Should we instead change the API of additive building blocks to explicitly take which outputs they should apply to?
I see it as a form of sparsity. We could return TensorMaps of zeros instead, but is that really better? This isn't only happening for these auxiliary outputs, but also for outputs that are not supported by the additive model (e.g. non-scalars for the composition model, or non-energies for ZBL). IMO, the issue with having more options and/or output requests is that model writers would have to take care of them, while right now a model writer can just drop the class in.
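To illustrate the two API styles being debated, here is a toy sketch (plain floats stand in for `TensorMap`s, the `+ 1.0` is a placeholder contribution, and both class names are hypothetical): one additive block that filters outputs implicitly, and one where the caller explicitly lists the outputs to apply to.

```python
from typing import Dict, List


class ImplicitAdditive:
    """Applies itself to all outputs, silently skipping unsupported ones."""

    SUPPORTED_PREFIX = "energy"  # e.g. ZBL only supports energy-like targets

    def forward(self, outputs: Dict[str, float]) -> Dict[str, float]:
        return {
            key: value + 1.0  # placeholder for the additive contribution
            for key, value in outputs.items()
            if key.startswith(self.SUPPORTED_PREFIX)  # implicit filtering
        }


class ExplicitAdditive:
    """The caller declares which outputs the building block applies to."""

    def __init__(self, apply_to: List[str]) -> None:
        self.apply_to = apply_to

    def forward(self, outputs: Dict[str, float]) -> Dict[str, float]:
        return {key: outputs[key] + 1.0 for key in self.apply_to}
```

The trade-off discussed in this thread: the implicit version can be dropped into any model unchanged, while the explicit version makes the behavior visible at the call site but forces every model writer to configure it.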
We can improve the documentation, though.
```python
self.register_buffer(
    "covariance",
    torch.zeros(
```

```python
device = next(self.model.parameters()).device
```
Since these are no longer buffers, will they be automatically moved to new dtypes/devices?
Alternatively, the API could change here as well to explicitly specify which output should we apply LLPR to. I feel like this would be a lot cleaner & more composable.
They are sent to the correct device in the forward pass; the dtype doesn't change (it corresponds to that of the model). I didn't find a better way, since torch doesn't support dicts of buffers. This is the same as what we do with labels.
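A sketch of the workaround described here, assuming a plain (non-buffer) dict of per-target tensors that is moved to the model's device inside `forward`; all names are illustrative, not the actual LLPR code:

```python
import torch


class PerTargetState(torch.nn.Module):
    """Toy module holding one covariance-like tensor per target.

    Since torch has no dict-of-buffers, the tensors live in a plain dict
    attribute (so `.to()` will NOT move them) and are moved lazily in
    `forward`, mirroring the approach described in the comment above.
    """

    def __init__(self, target_names, dim):
        super().__init__()
        self.linear = torch.nn.Linear(dim, dim)  # stands in for the wrapped model
        # plain dict of tensors: not registered, invisible to .to()/.cuda()
        self.covariances = {name: torch.zeros(dim, dim) for name in target_names}

    def forward(self, x):
        device = next(self.parameters()).device
        # move the per-target tensors to wherever the model currently lives
        self.covariances = {
            name: cov.to(device) for name, cov in self.covariances.items()
        }
        return self.linear(x)
```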
356e3f0 → bbe9b16
bbe9b16 → 0e190c4
Last part of #364.
Different types of targets make it necessary to have different types of heads in models. In this PR, heads are made more flexible. This also requires a change to the handling of internal features, as the last-layer features can now be per-target (or, equivalently, per-head). The LLPR module is adapted accordingly.
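The per-target heads could be sketched with a `torch.nn.ModuleDict` over a shared backbone (an illustrative toy model, not the actual implementation in this PR):

```python
import torch


class MultiHeadModel(torch.nn.Module):
    # Hypothetical sketch: one shared backbone plus one head per target,
    # so last-layer features are naturally per-target (per-head).
    def __init__(self, in_dim: int, targets: dict):
        super().__init__()
        self.backbone = torch.nn.Linear(in_dim, 16)
        self.heads = torch.nn.ModuleDict(
            {name: torch.nn.Linear(16, out_dim) for name, out_dim in targets.items()}
        )

    def forward(self, x: torch.Tensor):
        features = self.backbone(x)  # common representation before the heads
        # each head produces its own target; in general a head could also
        # have its own intermediate layers, making last-layer features differ
        return features, {name: head(features) for name, head in self.heads.items()}
```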
Finally, the last common representation of the model (before the heads are applied) is exposed as the `features` output (a standard output in `metatensor.torch.atomistic`).

Contributor (creator of pull-request) checklist
📚 Documentation preview 📚: https://metatrain--398.org.readthedocs.build/en/398/