Skip to content

Commit

Permalink
An initial version of refurnished documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
GardevoirX committed Jan 15, 2025
1 parent ee3cc7d commit 250d8c6
Show file tree
Hide file tree
Showing 16 changed files with 313 additions and 102 deletions.
1 change: 1 addition & 0 deletions docs/src/references/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,5 @@ refer to the :ref:`userdoc-how-to` section.
metatensor
lib/index
utils/index
tuning/index
changelog
15 changes: 15 additions & 0 deletions docs/src/references/tuning/base_classes.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
Base Classes
############

.. autoclass:: torchpme.tuning.tuner.TunerBase
:members:

.. autoclass:: torchpme.tuning.tuner.GridSearchTuner
:members:

.. autoclass:: torchpme.tuning.tuner.TuningTimings
:members:

.. autoclass:: torchpme.tuning.error_bounds.TuningErrorBounds
:members:

40 changes: 40 additions & 0 deletions docs/src/references/tuning/index.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
Tuning
######

The choice of parameters like the neighborlist ``cutoff``, the ``smearing`` or the
``lr_wavelength``/``mesh_spacing`` has a large influence one the accuracy of the
calculation. To help find the parameters that meet the accuracy requirements, this
module offers tuning methods for the calculators.

The scheme behind all tuning functions is grid-searching based, focusing on the Fourier
space parameters like ``lr_wavelength``, ``mesh_spacing`` and ``interpolation_nodes``.
For real space parameter ``cutoff``, it is treated as a hyperparameter here, which
should be manually specified by the user. The parameter ``smearing`` is determined by
the real space error formula and is set to achieve a real space error of
``desired_accuracy / 4``.

The Fourier space parameters are all discrete, so it's convenient to do the grid-search.
Default searching-ranges are provided for those parameters. For ``lr_wavelength``, the
values are chosen to be with a minimum of 1 and a maximum of 13 mesh points in each
spatial direction ``(x, y, z)``. For ``mesh_spacing``, the values are set to have
minimally 2 and maximally 7 mesh points in each spatial direction, for both the P3M and
PME method. The values of ``interpolation_nodes`` are the same as those supported in
:class:`torchpme.lib.MeshInterpolator`.

In the grid-searching, all possible parameter combinations are evaluated. The error
associated with the parameter is estimated by the error formulas implemented in the
subclasses of :class:`torchpme.tuning.error_bounds.TuningErrorBounds`. Parameter with
the error within the desired accuracy are benchmarked for computational time by
:class:`torchpme.tuning.tuner.TuningTimings` The timing of the other parameters are
not tested and set to infinity.

The return of these tuning functions contains the ``smearing`` and a dictionary, in
which there is parameter for the Fourier space. The parameter is that of the desired
accuracy and the shortest timing. The parameter of the smallest error will be returned
in the case that no parameter can fulfill the accuracy requirement.

.. toctree::
:maxdepth: 1
:glob:

./*
7 changes: 7 additions & 0 deletions docs/src/references/tuning/tune_ewald.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Tune Ewald
##########

.. autofunction:: torchpme.tuning.ewald.tune_ewald

.. autoclass:: torchpme.tuning.error_bounds.EwaldErrorBounds
:members:
7 changes: 7 additions & 0 deletions docs/src/references/tuning/tune_p3m.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Tune P3M
#########

.. autofunction:: torchpme.tuning.p3m.tune_p3m

.. autoclass:: torchpme.tuning.error_bounds.P3MErrorBounds
:members:
7 changes: 7 additions & 0 deletions docs/src/references/tuning/tune_pme.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Tune PME
#########

.. autofunction:: torchpme.tuning.pme.tune_pme

.. autoclass:: torchpme.tuning.error_bounds.PMEErrorBounds
:members:
22 changes: 0 additions & 22 deletions docs/src/references/utils/tuning.rst

This file was deleted.

9 changes: 5 additions & 4 deletions examples/01-charges-example.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
from metatensor.torch.atomistic import NeighborListOptions, System

import torchpme
from torchpme.utils.tuning.pme import PMETuner
from torchpme.tuning import tune_pme

# %%
#
Expand All @@ -57,9 +57,10 @@
# The ``sum_squared_charges`` is equal to ``2.0`` becaue each atom either has a charge
# of 1 or -1 in units of elementary charges.

smearing, pme_params, cutoff = PMETuner(
charges=charges, cell=cell, positions=positions, cutoff=4.4
).tune()
cutoff = 4.4
smearing, pme_params = tune_pme(
charges=charges, cell=cell, positions=positions, cutoff=cutoff
)

# %%
#
Expand Down
10 changes: 5 additions & 5 deletions examples/02-neighbor-lists-usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import vesin.torch

import torchpme
from torchpme.utils.tuning.pme import PMETuner
from torchpme.tuning import tune_pme

# %%
#
Expand Down Expand Up @@ -93,10 +93,10 @@
cell = torch.from_numpy(atoms.cell.array)

sum_squared_charges = float(torch.sum(charges**2))

smearing, pme_params, cutoff = PMETuner(
charges=charges, cell=cell, positions=positions, cutoff=4.4
).tune()
cutoff = 4.4
smearing, pme_params = tune_pme(
charges=charges, cell=cell, positions=positions, cutoff=cutoff
)

# %%
#
Expand Down
6 changes: 3 additions & 3 deletions examples/10-tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
import vesin.torch as vesin

import torchpme
from torchpme.utils.tuning import TuningTimings
from torchpme.utils.tuning.pme import PMEErrorBounds
from torchpme.tuning.error_bounds import PMEErrorBounds
from torchpme.tuning.tuner import TuningTimings

DTYPE = torch.float64

Expand Down Expand Up @@ -235,7 +235,7 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes):

# %%

EB = torchpme.utils.tuning.pme.PMEErrorBounds((charges**2).sum(), cell, positions)
EB = torchpme.tuning.error_bounds.PMEErrorBounds((charges**2).sum(), cell, positions)

# %%
v, t = timed_madelung(cutoff=5, smearing=1, mesh_spacing=1, interpolation_nodes=4)
Expand Down
2 changes: 1 addition & 1 deletion src/torchpme/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import contextlib

from . import calculators, lib, potentials, utils # noqa
from . import calculators, lib, potentials, tuning, utils # noqa
from ._version import __version__, __version_tuple__ # noqa
from .calculators import Calculator, EwaldCalculator, P3MCalculator, PMECalculator
from .potentials import (
Expand Down
109 changes: 90 additions & 19 deletions src/torchpme/tuning/error_bounds.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,22 @@
import math
import torch

import torch

TWO_PI = 2 * math.pi


class TuningErrorBounds(torch.nn.Module):
"""Base class for error bounds."""
"""
Base class for error bounds. This class calculates the real space error and the
Fourier space error based on the error formula. This class is used in the tuning
process. It can also be used with the :class:`torchpme.tuning.tuner.TunerBase` to
build up a custom parameter tuner.
:param charges: atomic charges
:param cell: single tensor of shape (3, 3), describing the bounding
:param positions: single tensor of shape (``len(charges), 3``) containing the
Cartesian positions of all point charges in the system.
"""

def __init__(
self,
Expand All @@ -21,7 +31,7 @@ def __init__(

def forward(self, *args, **kwargs):
return self.error(*args, **kwargs)


class EwaldErrorBounds(TuningErrorBounds):
r"""
Expand Down Expand Up @@ -60,22 +70,40 @@ def __init__(
self.cell = cell
self.positions = positions

def err_kspace(self, smearing, lr_wavelength):
def err_kspace(
self, smearing: torch.Tensor, lr_wavelength: torch.Tensor
) -> torch.Tensor:
"""
The Fourier space error of Ewald.
:param smearing: see :class:`torchpme.EwaldCalculator` for details
:param lr_wavelength: see :class:`torchpme.EwaldCalculator` for details
"""
smearing = torch.as_tensor(smearing)
lr_wavelength = torch.as_tensor(lr_wavelength)
return (
self.prefac**0.5
/ smearing
/ torch.sqrt(TWO_PI**2 * self.volume / (lr_wavelength) ** 0.5)
* torch.exp(-(TWO_PI**2) * smearing**2 / (lr_wavelength))
)

def err_rspace(self, smearing, cutoff):
def err_rspace(self, smearing: torch.Tensor, cutoff: torch.Tensor) -> torch.Tensor:
"""
The real space error of Ewald.
:param smearing: see :class:`torchpme.EwaldCalculator` for details
:param lr_wavelength: see :class:`torchpme.EwaldCalculator` for details
"""
return (
self.prefac
/ torch.sqrt(cutoff * self.volume)
* torch.exp(-(cutoff**2) / 2 / smearing**2)
)

def forward(self, smearing, lr_wavelength, cutoff):
def forward(
self, smearing: float, lr_wavelength: float, cutoff: float
) -> torch.Tensor:
r"""
Calculate the error bound of Ewald.
Expand Down Expand Up @@ -180,7 +208,19 @@ def __init__(
self.cell = cell
self.positions = positions

def err_kspace(self, smearing, mesh_spacing, interpolation_nodes):
def err_kspace(
self,
smearing: torch.Tensor,
mesh_spacing: torch.Tensor,
interpolation_nodes: torch.Tensor,
) -> torch.Tensor:
"""
The Fourier space error of P3M.
:param smearing: see :class:`torchpme.P3MCalculator` for details
:param mesh_spacing: see :class:`torchpme.P3MCalculator` for details
:param interpolation_nodes: see :class:`torchpme.P3MCalculator` for details
"""
actual_spacing = self.cell_dimensions / (
2 * self.cell_dimensions / mesh_spacing + 1
)
Expand All @@ -202,25 +242,33 @@ def err_kspace(self, smearing, mesh_spacing, interpolation_nodes):
)
)

def err_rspace(self, smearing, cutoff):
def err_rspace(self, smearing: torch.Tensor, cutoff: torch.Tensor) -> torch.Tensor:
"""
The real space error of P3M.
:param smearing: see :class:`torchpme.P3MCalculator` for details
:param cutoff: see :class:`torchpme.P3MCalculator` for details
"""
return (
self.prefac
/ torch.sqrt(cutoff * self.volume)
* torch.exp(-(cutoff**2) / 2 / smearing**2)
)

def forward(self, smearing, mesh_spacing, cutoff, interpolation_nodes):
def forward(
self,
smearing: float,
mesh_spacing: float,
cutoff: float,
interpolation_nodes: float,
) -> torch.Tensor:
r"""
Calculate the error bound of P3M.
:param smearing: see :class:`torchpme.P3MCalculator` for details
:param mesh_spacing: see :class:`torchpme.P3MCalculator` for details
:param cutoff: see :class:`torchpme.P3MCalculator` for details
:param interpolation_nodes: The number ``n`` of nodes used in the interpolation
per coordinate axis. The total number of interpolation nodes in 3D will be
``n^3``. In general, for ``n`` nodes, the interpolation will be performed by
piecewise polynomials of degree ``n`` (e.g. ``n = 3`` for cubic
interpolation). Only the values ``1, 2, 3, 4, 5`` are supported.
:param interpolation_nodes: see :class:`torchpme.P3MCalculator` for details
"""
smearing = torch.as_tensor(smearing)
mesh_spacing = torch.as_tensor(mesh_spacing)
Expand All @@ -230,7 +278,7 @@ def forward(self, smearing, mesh_spacing, cutoff, interpolation_nodes):
self.err_kspace(smearing, mesh_spacing, interpolation_nodes) ** 2
+ self.err_rspace(smearing, cutoff) ** 2
)


class PMEErrorBounds(TuningErrorBounds):
r"""
Expand Down Expand Up @@ -258,7 +306,19 @@ def __init__(
self.prefac = 2 * self.sum_squared_charges / math.sqrt(len(positions))
self.cell_dimensions = torch.linalg.norm(cell, dim=1)

def err_kspace(self, smearing, mesh_spacing, interpolation_nodes):
def err_kspace(
self,
smearing: torch.Tensor,
mesh_spacing: torch.Tensor,
interpolation_nodes: torch.Tensor,
) -> torch.Tensor:
"""
The Fourier space error of PME.
:param smearing: see :class:`torchpme.PMECalculator` for details
:param mesh_spacing: see :class:`torchpme.PMECalculator` for details
:param interpolation_nodes: see :class:`torchpme.PMECalculator` for details
"""
actual_spacing = self.cell_dimensions / (
2 * self.cell_dimensions / mesh_spacing + 1
)
Expand All @@ -279,7 +339,13 @@ def err_kspace(self, smearing, mesh_spacing, interpolation_nodes):
* RMS_phi[interpolation_nodes - 1]
)

def err_rspace(self, smearing, cutoff):
def err_rspace(self, smearing: torch.Tensor, cutoff: torch.Tensor) -> torch.Tensor:
"""
The real space error of PME.
:param smearing: see :class:`torchpme.PMECalculator` for details
:param cutoff: see :class:`torchpme.PMECalculator` for details
"""
smearing = torch.as_tensor(smearing)
cutoff = torch.as_tensor(cutoff)

Expand All @@ -289,7 +355,13 @@ def err_rspace(self, smearing, cutoff):
* torch.exp(-(cutoff**2) / 2 / smearing**2)
)

def error(self, cutoff, smearing, mesh_spacing, interpolation_nodes):
def error(
self,
cutoff: float,
smearing: float,
mesh_spacing: float,
interpolation_nodes: float,
) -> torch.Tensor:
r"""
Calculate the error bound of PME.
Expand All @@ -313,4 +385,3 @@ def error(self, cutoff, smearing, mesh_spacing, interpolation_nodes):
self.err_rspace(smearing, cutoff) ** 2
+ self.err_kspace(smearing, mesh_spacing, interpolation_nodes) ** 2
)

Loading

0 comments on commit 250d8c6

Please sign in to comment.