Revamped tuning #130

Open · wants to merge 36 commits into main

Commits (36)
2466ea5  Initial version of `grid_search` (GardevoirX, Nov 26, 2024)
0b35dbc  Remove error (GardevoirX, Nov 26, 2024)
fe2c8a5  Allow a precomputed nl (GardevoirX, Nov 26, 2024)
0ca11de  Renamed examples, and added a tuning playground (ceriottm, Nov 23, 2024)
d72f818  Nelder mead (doesn't work because actual error is not a good target) (ceriottm, Nov 24, 2024)
cedfe47  Added a tuning class (ceriottm, Nov 24, 2024)
6301dcb  I'm not a morning person it seems (ceriottm, Nov 24, 2024)
e37145a  Examples (ceriottm, Nov 24, 2024)
e4eb476  Better plotting (ceriottm, Nov 24, 2024)
8d7c3e0  Fixes on `H` and `RMS_phi` (GardevoirX, Nov 25, 2024)
33b479e  Some cleaning and test fix (GardevoirX, Nov 25, 2024)
5e70197  Further clean (GardevoirX, Nov 26, 2024)
f2aa91a  Replace `loss` in tuning with `ErrorBounds` and draft for `Tuner` (GardevoirX, Nov 27, 2024)
816121a  Suppress output (GardevoirX, Nov 27, 2024)
38a4705  Update `grid_search` (GardevoirX, Nov 28, 2024)
22c2c2c  Return something when it cannot reach desired accuracy (GardevoirX, Nov 28, 2024)
8f124b3  Suppress output (GardevoirX, Nov 28, 2024)
687098e  Repair some errors of the example (GardevoirX, Nov 28, 2024)
69631d4  Add a warning for the case that no parameter can meet the accuracy re… (GardevoirX, Dec 5, 2024)
b7a8ad6  Update warning (GardevoirX, Dec 5, 2024)
666bf7b  Documentation and pytest updates (GardevoirX, Dec 18, 2024)
d603d16  Added a TIP4P example (ceriottm, Dec 20, 2024)
647d697  Started to change the API to use full charges rather than the sum of … (ceriottm, Dec 20, 2024)
3510737  Move from `sum_squared_charges` to `charges` (GardevoirX, Dec 28, 2024)
308b281  Refactor the tuning methods with a base class (GardevoirX, Dec 28, 2024)
eb9290c  Fix pytests and make linter happy (GardevoirX, Dec 28, 2024)
fe370e0  Mini cleanups (ceriottm, Dec 29, 2024)
dd19651  Docs fix (GardevoirX, Dec 29, 2024)
ec702cb  Separate timings calculator (ceriottm, Dec 29, 2024)
d82d219  Linting (ceriottm, Dec 29, 2024)
8806505  Try fix github action failures (GardevoirX, Dec 29, 2024)
d14c72e  Add tuning functions back (GardevoirX, Jan 7, 2025)
56fc382  Allow doctests (GardevoirX, Jan 7, 2025)
33c9705  Fix doctests and remove orphan functions (GardevoirX, Jan 7, 2025)
e5644b4  Fix ewald doctest again and remove unused members (GardevoirX, Jan 7, 2025)
b617d0a  Formatting (GardevoirX, Jan 7, 2025)
7 changes: 5 additions & 2 deletions docs/src/references/utils/tuning.rst
@@ -12,8 +12,11 @@
than the given accuracy. Because these methods are gradient-based, be sure to pay
attention to the ``learning_rate`` and ``max_steps`` parameters. A good choice of these
two parameters can enhance the optimization speed and performance.

-.. autoclass:: torchpme.utils.tune_ewald
+.. autoclass:: torchpme.utils.tuning.ewald.EwaldTuner
:members:

-.. autoclass:: torchpme.utils.tune_pme
+.. autoclass:: torchpme.utils.tuning.pme.PMETuner
:members:

.. autoclass:: torchpme.utils.tuning.p3m.P3MTuner
:members:
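
A minimal usage sketch of the new class-based interface; it assumes ``EwaldTuner`` shares the constructor signature and ``tune()`` return convention of the ``PMETuner`` calls shown in the example diffs below, which is an assumption rather than documented behavior:

import torch

from torchpme.utils.tuning.ewald import EwaldTuner

# toy CsCl-like system, mirroring the CsCl example below
charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64)
positions = torch.tensor([(0.0, 0.0, 0.0), (0.5, 0.5, 0.5)], dtype=torch.float64)
cell = torch.eye(3, dtype=torch.float64)

# assumed return values: the Gaussian smearing, a dict of reciprocal-space
# parameters, and the real-space cutoff (same convention as PMETuner below)
smearing, ewald_params, cutoff = EwaldTuner(
    charges=charges, cell=cell, positions=positions, cutoff=4.4
).tune()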
(changes to an example file; the file name was not captured in this view)
@@ -37,13 +37,15 @@
from metatensor.torch.atomistic import NeighborListOptions, System

import torchpme
from torchpme.utils.tuning.pme import PMETuner

# %%
#
# Create the properties CsCl unit cell

symbols = ("Cs", "Cl")
types = torch.tensor([55, 17])
charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64)
positions = torch.tensor([(0, 0, 0), (0.5, 0.5, 0.5)], dtype=torch.float64)
cell = torch.eye(3, dtype=torch.float64)
pbc = torch.tensor([True, True, True])
@@ -55,9 +57,9 @@
# The ``sum_squared_charges`` is equal to ``2.0`` because each atom either has a charge
# of 1 or -1 in units of elementary charges.

-smearing, pme_params, cutoff = torchpme.utils.tune_pme(
-    sum_squared_charges=2.0, cell=cell, positions=positions
-)
+smearing, pme_params, cutoff = PMETuner(
+    charges=charges, cell=cell, positions=positions, cutoff=4.4
+).tune()

# %%
#
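Once tuned, the returned values are meant to be handed to a calculator. A hedged sketch of that hand-off follows; ``torchpme.CoulombPotential`` and ``torchpme.PMECalculator`` are assumed class names, and ``pme_params`` is assumed to unpack into calculator keyword arguments:

# class names and the **pme_params unpacking are assumptions here
potential = torchpme.CoulombPotential(smearing=smearing)
calculator = torchpme.PMECalculator(potential=potential, **pme_params)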
(changes to another example file; the file name was not captured in this view)
@@ -46,6 +46,7 @@
import vesin.torch

import torchpme
from torchpme.utils.tuning.pme import PMETuner

# %%
#
@@ -93,9 +94,9 @@

sum_squared_charges = float(torch.sum(charges**2))

-smearing, pme_params, cutoff = torchpme.utils.tune_pme(
-    sum_squared_charges=sum_squared_charges, cell=cell, positions=positions
-)
+smearing, pme_params, cutoff = PMETuner(
+    charges=charges, cell=cell, positions=positions, cutoff=4.4
+).tune()

# %%
#
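The tuned ``cutoff`` can also be reused when building the neighbor list with ``vesin.torch`` (imported at the top of this example). A sketch following vesin's documented interface; treat the exact signature as an assumption:

nl = vesin.torch.NeighborList(cutoff=cutoff, full_list=False)
# pair indices i/j, cell shift vectors S, and pair distances d
i, j, S, d = nl.compute(points=positions, box=cell, periodic=True, quantities="ijSd")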
File renamed without changes.
File renamed without changes.
86 changes: 84 additions & 2 deletions examples/5-autograd-demo.py → examples/05-autograd-demo.py
@@ -17,6 +17,8 @@
exercise to the reader.
"""

# %%

from time import time

import ase
@@ -477,10 +479,11 @@ def forward(self, positions, cell, charges):
)

# %%
-# We can also time the difference in execution
+# We can also evaluate the difference in execution
# time between the Pytorch and scripted versions of the
# module (depending on the system, the relative efficiency
-# of the two evaluations could go either way!)
+# of the two evaluations could go either way, as this is
+# too small a system to make a difference!)

duration = 0.0
for _i in range(20):
@@ -515,3 +518,82 @@ def forward(self, positions, cell, charges):
print(f"Evaluation time:\nPytorch: {time_python}ms\nJitted: {time_jit}ms")
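
# %%
# A minimal, self-contained sketch of producing a scripted module with
# ``torch.jit.script``; the class below is a hypothetical stand-in, not
# the module defined in this example.


class _ScriptDemo(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # a trivial computation, just to have something to script
        return (x**2).sum()


scripted = torch.jit.script(_ScriptDemo())
print(scripted(torch.ones(3)))  # tensor(3.)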

# %%
# Other auto-differentiation ideas
Review comment (Contributor):
IMHO I wouldn't put this example here, even though I think it is good to have it. The tutorial is already 500 lines, and with this it gets super long. I'd rather vote for smaller examples tackling one specific task each. Finding solutions is much easier if they are shorter. See also the beloved matplotlib examples.

# --------------------------------
#
# There are many other ways the auto-differentiation engine of
# ``torch`` can be used to facilitate the evaluation of atomistic
# models.

# %%
# 4-site water models
# ~~~~~~~~~~~~~~~~~~~
#
# Several water models (starting from the venerable TIP4P model of
# `Abascal and C. Vega, JCP (2005) <http://doi.org/10.1063/1.2121687>`_)
# use a center of negative charge that is displaced from the O position.
# This is easily implemented, yielding the forces on the O and H positions
# generated by the displaced charge.

structure = ase.Atoms(
positions=[
[0, 0, 0],
[0, 1, 0],
[1, -0.2, 0],
],
cell=[6, 6, 6],
symbols="OHH",
)

cell = torch.from_numpy(structure.cell.array).to(device=device, dtype=dtype)
positions = torch.from_numpy(structure.positions).to(device=device, dtype=dtype)

# %%
# The key step is to create a "fourth site" based on the O positions
# and use it in the ``interpolate`` step.

charges = torch.tensor(
[[-1.0], [0.5], [0.5]],
dtype=dtype,
device=device,
)

positions.requires_grad_(True)
charges.requires_grad_(True)
cell.requires_grad_(True)

positions_4site = torch.vstack(
[
((positions[1::3] + positions[2::3]) * 0.5 + positions[0::3] * 3) / 4,
positions[1::3],
positions[2::3],
]
)
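# the weighted average above places the virtual site at
# 3/4 * r_O + 1/8 * r_H1 + 1/8 * r_H2, i.e. displaced from the O
# towards the midpoint of the two hydrogens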

ns = torch.tensor([5, 5, 5])
interpolator = torchpme.lib.MeshInterpolator(
cell=cell, ns_mesh=ns, interpolation_nodes=3, method="Lagrange"
)
interpolator.compute_weights(positions_4site)
mesh = interpolator.points_to_mesh(charges)
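
# the squared mesh density below is only a stand-in scalar objective,
# chosen so there is something to backpropagate through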

value = (mesh**2).sum()

# %%
# The gradients can be computed by just running `backward` on the
# end result. Gradients are computed on the H and O positions.

value.backward()

print(
f"""
Position gradients:
{positions.grad.T}

Cell gradients:
{cell.grad}

Charges gradients:
{charges.grad.T}
"""
)
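
# %%
# A minimal sketch (assuming nothing beyond ``torch`` itself) verifying
# the chain rule at work here: the virtual site is
# 3/4 * r_O + 1/8 * r_H1 + 1/8 * r_H2, so a scalar that depends only on
# the M-site distributes its gradient with exactly those weights.

r_O = torch.zeros(3, dtype=torch.float64, requires_grad=True)
r_H1 = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float64, requires_grad=True)
r_H2 = torch.tensor([1.0, -0.2, 0.0], dtype=torch.float64, requires_grad=True)

r_m = ((r_H1 + r_H2) * 0.5 + r_O * 3) / 4  # same weights as positions_4site
energy = r_m.sum()  # dE/dr_m = (1, 1, 1)

g_o, g_h1, g_h2 = torch.autograd.grad(energy, (r_O, r_H1, r_H2))
print(g_o)   # tensor([0.7500, 0.7500, 0.7500], dtype=torch.float64)
print(g_h1)  # tensor([0.1250, 0.1250, 0.1250], dtype=torch.float64)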
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.