From 250d8c692add4865b44fca61671f1c8055007377 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Thu, 16 Jan 2025 00:17:43 +0100 Subject: [PATCH] An initial version of refurnished documentation --- docs/src/references/index.rst | 1 + docs/src/references/tuning/base_classes.rst | 15 +++ docs/src/references/tuning/index.rst | 40 ++++++ docs/src/references/tuning/tune_ewald.rst | 7 ++ docs/src/references/tuning/tune_p3m.rst | 7 ++ docs/src/references/tuning/tune_pme.rst | 7 ++ docs/src/references/utils/tuning.rst | 22 ---- examples/01-charges-example.py | 9 +- examples/02-neighbor-lists-usage.py | 10 +- examples/10-tuning.py | 6 +- src/torchpme/__init__.py | 2 +- src/torchpme/tuning/error_bounds.py | 109 +++++++++++++--- src/torchpme/tuning/ewald.py | 29 ++--- src/torchpme/tuning/p3m.py | 8 +- src/torchpme/tuning/pme.py | 11 +- src/torchpme/tuning/tuner.py | 132 ++++++++++++++++---- 16 files changed, 313 insertions(+), 102 deletions(-) create mode 100644 docs/src/references/tuning/base_classes.rst create mode 100644 docs/src/references/tuning/index.rst create mode 100644 docs/src/references/tuning/tune_ewald.rst create mode 100644 docs/src/references/tuning/tune_p3m.rst create mode 100644 docs/src/references/tuning/tune_pme.rst delete mode 100644 docs/src/references/utils/tuning.rst diff --git a/docs/src/references/index.rst b/docs/src/references/index.rst index 656fa061..29c1bb4a 100644 --- a/docs/src/references/index.rst +++ b/docs/src/references/index.rst @@ -25,4 +25,5 @@ refer to the :ref:`userdoc-how-to` section. metatensor lib/index utils/index + tuning/index changelog diff --git a/docs/src/references/tuning/base_classes.rst b/docs/src/references/tuning/base_classes.rst new file mode 100644 index 00000000..860efffd --- /dev/null +++ b/docs/src/references/tuning/base_classes.rst @@ -0,0 +1,15 @@ +Base Classes +############ + +.. autoclass:: torchpme.tuning.tuner.TunerBase + :members: + +.. autoclass:: torchpme.tuning.tuner.GridSearchTuner + :members: + +.. autoclass:: torchpme.tuning.tuner.TuningTimings + :members: + +.. autoclass:: torchpme.tuning.error_bounds.TuningErrorBounds + :members: + diff --git a/docs/src/references/tuning/index.rst b/docs/src/references/tuning/index.rst new file mode 100644 index 00000000..5333dd03 --- /dev/null +++ b/docs/src/references/tuning/index.rst @@ -0,0 +1,40 @@ +Tuning +###### + +The choice of parameters like the neighborlist ``cutoff``, the ``smearing`` or the +``lr_wavelength``/``mesh_spacing`` has a large influence one the accuracy of the +calculation. To help find the parameters that meet the accuracy requirements, this +module offers tuning methods for the calculators. + +The scheme behind all tuning functions is grid-searching based, focusing on the Fourier +space parameters like ``lr_wavelength``, ``mesh_spacing`` and ``interpolation_nodes``. +For real space parameter ``cutoff``, it is treated as a hyperparameter here, which +should be manually specified by the user. The parameter ``smearing`` is determined by +the real space error formula and is set to achieve a real space error of +``desired_accuracy / 4``. + +The Fourier space parameters are all discrete, so it's convenient to do the grid-search. +Default searching-ranges are provided for those parameters. For ``lr_wavelength``, the +values are chosen to be with a minimum of 1 and a maximum of 13 mesh points in each +spatial direction ``(x, y, z)``. For ``mesh_spacing``, the values are set to have +minimally 2 and maximally 7 mesh points in each spatial direction, for both the P3M and +PME method. The values of ``interpolation_nodes`` are the same as those supported in +:class:`torchpme.lib.MeshInterpolator`. + +In the grid-searching, all possible parameter combinations are evaluated. The error +associated with the parameter is estimated by the error formulas implemented in the +subclasses of :class:`torchpme.tuning.error_bounds.TuningErrorBounds`. Parameter with +the error within the desired accuracy are benchmarked for computational time by +:class:`torchpme.tuning.tuner.TuningTimings` The timing of the other parameters are +not tested and set to infinity. + +The return of these tuning functions contains the ``smearing`` and a dictionary, in +which there is parameter for the Fourier space. The parameter is that of the desired +accuracy and the shortest timing. The parameter of the smallest error will be returned +in the case that no parameter can fulfill the accuracy requirement. + +.. toctree:: + :maxdepth: 1 + :glob: + + ./* diff --git a/docs/src/references/tuning/tune_ewald.rst b/docs/src/references/tuning/tune_ewald.rst new file mode 100644 index 00000000..b33c6683 --- /dev/null +++ b/docs/src/references/tuning/tune_ewald.rst @@ -0,0 +1,7 @@ +Tune Ewald +########## + +.. autofunction:: torchpme.tuning.ewald.tune_ewald + +.. autoclass:: torchpme.tuning.error_bounds.EwaldErrorBounds + :members: diff --git a/docs/src/references/tuning/tune_p3m.rst b/docs/src/references/tuning/tune_p3m.rst new file mode 100644 index 00000000..f561712b --- /dev/null +++ b/docs/src/references/tuning/tune_p3m.rst @@ -0,0 +1,7 @@ +Tune P3M +######### + +.. autofunction:: torchpme.tuning.p3m.tune_p3m + +.. autoclass:: torchpme.tuning.error_bounds.P3MErrorBounds + :members: diff --git a/docs/src/references/tuning/tune_pme.rst b/docs/src/references/tuning/tune_pme.rst new file mode 100644 index 00000000..54c66776 --- /dev/null +++ b/docs/src/references/tuning/tune_pme.rst @@ -0,0 +1,7 @@ +Tune PME +######### + +.. autofunction:: torchpme.tuning.pme.tune_pme + +.. autoclass:: torchpme.tuning.error_bounds.PMEErrorBounds + :members: diff --git a/docs/src/references/utils/tuning.rst b/docs/src/references/utils/tuning.rst deleted file mode 100644 index 7e47fbef..00000000 --- a/docs/src/references/utils/tuning.rst +++ /dev/null @@ -1,22 +0,0 @@ -Tuning -###### - -The choice of parameters like the neighborlist ``cutoff``, the ``smearing`` or the -``lr_wavelength``/``mesh_spacing`` has a large influence one the accuracy of the -calculation. To help find the parameters that meet the accuracy requirements, this -module offers tuning methods for the calculators. - -The scheme behind all tuning functions is a gradient-based optimization, which tries to -find the minimal of the error estimation formula and stops after the error is smaller -than the given accuracy. Because these methods are gradient-based, be sure to pay -attention to the ``learning_rate`` and ``max_steps`` parameter. A good choice of these -two parameters can enhance the optimization speed and performance. - -.. autoclass:: torchpme.utils.tuning.ewald.EwaldTuner - :members: - -.. autoclass:: torchpme.utils.tuning.pme.PMETuner - :members: - -.. autoclass:: torchpme.utils.tuning.p3m.P3MTuner - :members: diff --git a/examples/01-charges-example.py b/examples/01-charges-example.py index 8f3b0f8c..fea5ec13 100644 --- a/examples/01-charges-example.py +++ b/examples/01-charges-example.py @@ -37,7 +37,7 @@ from metatensor.torch.atomistic import NeighborListOptions, System import torchpme -from torchpme.utils.tuning.pme import PMETuner +from torchpme.tuning import tune_pme # %% # @@ -57,9 +57,10 @@ # The ``sum_squared_charges`` is equal to ``2.0`` becaue each atom either has a charge # of 1 or -1 in units of elementary charges. -smearing, pme_params, cutoff = PMETuner( - charges=charges, cell=cell, positions=positions, cutoff=4.4 -).tune() +cutoff = 4.4 +smearing, pme_params = tune_pme( + charges=charges, cell=cell, positions=positions, cutoff=cutoff +) # %% # diff --git a/examples/02-neighbor-lists-usage.py b/examples/02-neighbor-lists-usage.py index 2026b732..dba448f0 100644 --- a/examples/02-neighbor-lists-usage.py +++ b/examples/02-neighbor-lists-usage.py @@ -46,7 +46,7 @@ import vesin.torch import torchpme -from torchpme.utils.tuning.pme import PMETuner +from torchpme.tuning import tune_pme # %% # @@ -93,10 +93,10 @@ cell = torch.from_numpy(atoms.cell.array) sum_squared_charges = float(torch.sum(charges**2)) - -smearing, pme_params, cutoff = PMETuner( - charges=charges, cell=cell, positions=positions, cutoff=4.4 -).tune() +cutoff = 4.4 +smearing, pme_params = tune_pme( + charges=charges, cell=cell, positions=positions, cutoff=cutoff +) # %% # diff --git a/examples/10-tuning.py b/examples/10-tuning.py index fb5a8a8f..5a36e2d7 100644 --- a/examples/10-tuning.py +++ b/examples/10-tuning.py @@ -19,8 +19,8 @@ import vesin.torch as vesin import torchpme -from torchpme.utils.tuning import TuningTimings -from torchpme.utils.tuning.pme import PMEErrorBounds +from torchpme.tuning.error_bounds import PMEErrorBounds +from torchpme.tuning.tuner import TuningTimings DTYPE = torch.float64 @@ -235,7 +235,7 @@ def timed_madelung(cutoff, smearing, mesh_spacing, interpolation_nodes): # %% -EB = torchpme.utils.tuning.pme.PMEErrorBounds((charges**2).sum(), cell, positions) +EB = torchpme.tuning.error_bounds.PMEErrorBounds((charges**2).sum(), cell, positions) # %% v, t = timed_madelung(cutoff=5, smearing=1, mesh_spacing=1, interpolation_nodes=4) diff --git a/src/torchpme/__init__.py b/src/torchpme/__init__.py index 8687dca8..27c85d84 100644 --- a/src/torchpme/__init__.py +++ b/src/torchpme/__init__.py @@ -1,6 +1,6 @@ import contextlib -from . import calculators, lib, potentials, utils # noqa +from . import calculators, lib, potentials, tuning, utils # noqa from ._version import __version__, __version_tuple__ # noqa from .calculators import Calculator, EwaldCalculator, P3MCalculator, PMECalculator from .potentials import ( diff --git a/src/torchpme/tuning/error_bounds.py b/src/torchpme/tuning/error_bounds.py index 3baf5465..aff62129 100644 --- a/src/torchpme/tuning/error_bounds.py +++ b/src/torchpme/tuning/error_bounds.py @@ -1,12 +1,22 @@ import math -import torch +import torch TWO_PI = 2 * math.pi class TuningErrorBounds(torch.nn.Module): - """Base class for error bounds.""" + """ + Base class for error bounds. This class calculates the real space error and the + Fourier space error based on the error formula. This class is used in the tuning + process. It can also be used with the :class:`torchpme.tuning.tuner.TunerBase` to + build up a custom parameter tuner. + + :param charges: atomic charges + :param cell: single tensor of shape (3, 3), describing the bounding + :param positions: single tensor of shape (``len(charges), 3``) containing the + Cartesian positions of all point charges in the system. + """ def __init__( self, @@ -21,7 +31,7 @@ def __init__( def forward(self, *args, **kwargs): return self.error(*args, **kwargs) - + class EwaldErrorBounds(TuningErrorBounds): r""" @@ -60,7 +70,17 @@ def __init__( self.cell = cell self.positions = positions - def err_kspace(self, smearing, lr_wavelength): + def err_kspace( + self, smearing: torch.Tensor, lr_wavelength: torch.Tensor + ) -> torch.Tensor: + """ + The Fourier space error of Ewald. + + :param smearing: see :class:`torchpme.EwaldCalculator` for details + :param lr_wavelength: see :class:`torchpme.EwaldCalculator` for details + """ + smearing = torch.as_tensor(smearing) + lr_wavelength = torch.as_tensor(lr_wavelength) return ( self.prefac**0.5 / smearing @@ -68,14 +88,22 @@ def err_kspace(self, smearing, lr_wavelength): * torch.exp(-(TWO_PI**2) * smearing**2 / (lr_wavelength)) ) - def err_rspace(self, smearing, cutoff): + def err_rspace(self, smearing: torch.Tensor, cutoff: torch.Tensor) -> torch.Tensor: + """ + The real space error of Ewald. + + :param smearing: see :class:`torchpme.EwaldCalculator` for details + :param lr_wavelength: see :class:`torchpme.EwaldCalculator` for details + """ return ( self.prefac / torch.sqrt(cutoff * self.volume) * torch.exp(-(cutoff**2) / 2 / smearing**2) ) - def forward(self, smearing, lr_wavelength, cutoff): + def forward( + self, smearing: float, lr_wavelength: float, cutoff: float + ) -> torch.Tensor: r""" Calculate the error bound of Ewald. @@ -180,7 +208,19 @@ def __init__( self.cell = cell self.positions = positions - def err_kspace(self, smearing, mesh_spacing, interpolation_nodes): + def err_kspace( + self, + smearing: torch.Tensor, + mesh_spacing: torch.Tensor, + interpolation_nodes: torch.Tensor, + ) -> torch.Tensor: + """ + The Fourier space error of P3M. + + :param smearing: see :class:`torchpme.P3MCalculator` for details + :param mesh_spacing: see :class:`torchpme.P3MCalculator` for details + :param interpolation_nodes: see :class:`torchpme.P3MCalculator` for details + """ actual_spacing = self.cell_dimensions / ( 2 * self.cell_dimensions / mesh_spacing + 1 ) @@ -202,25 +242,33 @@ def err_kspace(self, smearing, mesh_spacing, interpolation_nodes): ) ) - def err_rspace(self, smearing, cutoff): + def err_rspace(self, smearing: torch.Tensor, cutoff: torch.Tensor) -> torch.Tensor: + """ + The real space error of P3M. + + :param smearing: see :class:`torchpme.P3MCalculator` for details + :param cutoff: see :class:`torchpme.P3MCalculator` for details + """ return ( self.prefac / torch.sqrt(cutoff * self.volume) * torch.exp(-(cutoff**2) / 2 / smearing**2) ) - def forward(self, smearing, mesh_spacing, cutoff, interpolation_nodes): + def forward( + self, + smearing: float, + mesh_spacing: float, + cutoff: float, + interpolation_nodes: float, + ) -> torch.Tensor: r""" Calculate the error bound of P3M. :param smearing: see :class:`torchpme.P3MCalculator` for details :param mesh_spacing: see :class:`torchpme.P3MCalculator` for details :param cutoff: see :class:`torchpme.P3MCalculator` for details - :param interpolation_nodes: The number ``n`` of nodes used in the interpolation - per coordinate axis. The total number of interpolation nodes in 3D will be - ``n^3``. In general, for ``n`` nodes, the interpolation will be performed by - piecewise polynomials of degree ``n`` (e.g. ``n = 3`` for cubic - interpolation). Only the values ``1, 2, 3, 4, 5`` are supported. + :param interpolation_nodes: see :class:`torchpme.P3MCalculator` for details """ smearing = torch.as_tensor(smearing) mesh_spacing = torch.as_tensor(mesh_spacing) @@ -230,7 +278,7 @@ def forward(self, smearing, mesh_spacing, cutoff, interpolation_nodes): self.err_kspace(smearing, mesh_spacing, interpolation_nodes) ** 2 + self.err_rspace(smearing, cutoff) ** 2 ) - + class PMEErrorBounds(TuningErrorBounds): r""" @@ -258,7 +306,19 @@ def __init__( self.prefac = 2 * self.sum_squared_charges / math.sqrt(len(positions)) self.cell_dimensions = torch.linalg.norm(cell, dim=1) - def err_kspace(self, smearing, mesh_spacing, interpolation_nodes): + def err_kspace( + self, + smearing: torch.Tensor, + mesh_spacing: torch.Tensor, + interpolation_nodes: torch.Tensor, + ) -> torch.Tensor: + """ + The Fourier space error of PME. + + :param smearing: see :class:`torchpme.PMECalculator` for details + :param mesh_spacing: see :class:`torchpme.PMECalculator` for details + :param interpolation_nodes: see :class:`torchpme.PMECalculator` for details + """ actual_spacing = self.cell_dimensions / ( 2 * self.cell_dimensions / mesh_spacing + 1 ) @@ -279,7 +339,13 @@ def err_kspace(self, smearing, mesh_spacing, interpolation_nodes): * RMS_phi[interpolation_nodes - 1] ) - def err_rspace(self, smearing, cutoff): + def err_rspace(self, smearing: torch.Tensor, cutoff: torch.Tensor) -> torch.Tensor: + """ + The real space error of PME. + + :param smearing: see :class:`torchpme.PMECalculator` for details + :param cutoff: see :class:`torchpme.PMECalculator` for details + """ smearing = torch.as_tensor(smearing) cutoff = torch.as_tensor(cutoff) @@ -289,7 +355,13 @@ def err_rspace(self, smearing, cutoff): * torch.exp(-(cutoff**2) / 2 / smearing**2) ) - def error(self, cutoff, smearing, mesh_spacing, interpolation_nodes): + def error( + self, + cutoff: float, + smearing: float, + mesh_spacing: float, + interpolation_nodes: float, + ) -> torch.Tensor: r""" Calculate the error bound of PME. @@ -313,4 +385,3 @@ def error(self, cutoff, smearing, mesh_spacing, interpolation_nodes): self.err_rspace(smearing, cutoff) ** 2 + self.err_kspace(smearing, mesh_spacing, interpolation_nodes) ** 2 ) - diff --git a/src/torchpme/tuning/ewald.py b/src/torchpme/tuning/ewald.py index 76188e9c..a0e3b5d4 100644 --- a/src/torchpme/tuning/ewald.py +++ b/src/torchpme/tuning/ewald.py @@ -20,7 +20,7 @@ def tune_ewald( ns_lo: int = 1, ns_hi: int = 14, accuracy: float = 1e-3, -): +) -> tuple[float, dict[str, float]]: r""" Find the optimal parameters for :class:`torchpme.EwaldCalculator`. @@ -39,22 +39,21 @@ def tune_ewald( :param charges: torch.Tensor, atomic (pseudo-)charges :param cell: torch.Tensor, periodic supercell for the system - :param positions: torch.Tensor, Cartesian coordinates of the particles within - the supercell. + :param positions: torch.Tensor, Cartesian coordinates of the particles within the + supercell. :param cutoff: float, cutoff distance for the neighborlist - :param exponent :math:`p` in :math:`1/r^p` potentials, currently only :math:`p=1` is - supported + :param exponent: :math:`p` in :math:`1/r^p` potentials, currently only :math:`p=1` + is supported :param neighbor_indices: torch.Tensor with the ``i,j`` indices of neighbors for which the potential should be computed in real space. - :param neighbor_distances: torch.Tensor with the pair distances of the neighbors - for which the potential should be computed in real space. + :param neighbor_distances: torch.Tensor with the pair distances of the neighbors for + which the potential should be computed in real space. :param accuracy: Recomended values for a balance between the accuracy and speed is :math:`10^{-3}`. For more accurate results, use :math:`10^{-6}`. :return: Tuple containing a float of the optimal smearing for the :class: - `CoulombPotential`, a dictionary with the parameters for - :class:`EwaldCalculator` and a float of the optimal cutoff value for the - neighborlist computation. + `CoulombPotential`, and a dictionary with the parameters for + :class:`EwaldCalculator`. Example ------- @@ -64,7 +63,7 @@ def tune_ewald( ... ) >>> charges = torch.tensor([[1.0], [-1.0]], dtype=torch.float64) >>> cell = torch.eye(3, dtype=torch.float64) - >>> smearing, parameter, cutoff = tune_ewald( + >>> smearing, parameter = tune_ewald( ... charges, cell, positions, cutoff=4.4, accuracy=1e-1 ... ) @@ -74,13 +73,11 @@ def tune_ewald( 1.7140874893066034 >>> print(parameter) - {'lr_wavelength': 0.25} - - >>> print(cutoff) - 4.4 + {'lr_wavelength': 0.3333333333333333} """ - params = [{"lr_wavelength": ns} for ns in range(ns_lo, ns_hi + 1)] + min_dimension = float(torch.min(torch.linalg.norm(cell, dim=1))) + params = [{"lr_wavelength": min_dimension / ns} for ns in range(ns_lo, ns_hi + 1)] tuner = GridSearchTuner( charges, cell, diff --git a/src/torchpme/tuning/p3m.py b/src/torchpme/tuning/p3m.py index b3e44021..7d01976a 100644 --- a/src/torchpme/tuning/p3m.py +++ b/src/torchpme/tuning/p3m.py @@ -96,7 +96,7 @@ def tune_p3m( :param positions: torch.Tensor, Cartesian coordinates of the particles within the supercell. :param cutoff: float, cutoff distance for the neighborlist - :param exponent :math:`p` in :math:`1/r^p` potentials, currently only :math:`p=1` is + :param exponent: :math:`p` in :math:`1/r^p` potentials, currently only :math:`p=1` is supported :param neighbor_indices: torch.Tensor with the ``i,j`` indices of neighbors for which the potential should be computed in real space. @@ -132,7 +132,7 @@ def tune_p3m( 1.7140874893066034 >>> print(parameter) - {'interpolation_nodes': 2, 'mesh_spacing': 0.2857142857142857} + {'interpolation_nodes': 3, 'mesh_spacing': 0.6666666666666666} >>> print(cutoff) 4.4 @@ -142,9 +142,9 @@ def tune_p3m( params = [ { "interpolation_nodes": interpolation_nodes, - "mesh_spacing": 2 * min_dimension / (2**mesh_spacing - 1), + "mesh_spacing": 2 * min_dimension / (2**ns - 1), } - for interpolation_nodes, mesh_spacing in product( + for interpolation_nodes, ns in product( range(nodes_lo, nodes_hi + 1), range(mesh_lo, mesh_hi + 1) ) ] diff --git a/src/torchpme/tuning/pme.py b/src/torchpme/tuning/pme.py index ec983f97..6550ccfe 100644 --- a/src/torchpme/tuning/pme.py +++ b/src/torchpme/tuning/pme.py @@ -39,7 +39,7 @@ def tune_pme( :param positions: torch.Tensor, Cartesian coordinates of the particles within the supercell. :param cutoff: float, cutoff distance for the neighborlist - :param exponent :math:`p` in :math:`1/r^p` potentials, currently only :math:`p=1` is + :param exponent: :math:`p` in :math:`1/r^p` potentials, currently only :math:`p=1` is supported :param neighbor_indices: torch.Tensor with the ``i,j`` indices of neighbors for which the potential should be computed in real space. @@ -75,19 +75,16 @@ def tune_pme( 1.7140874893066034 >>> print(parameter) - {'interpolation_nodes': 3, 'mesh_spacing': 0.2857142857142857} - - >>> print(cutoff) - 4.4 + {'interpolation_nodes': 4, 'mesh_spacing': 0.6666666666666666} """ min_dimension = float(torch.min(torch.linalg.norm(cell, dim=1))) params = [ { "interpolation_nodes": interpolation_nodes, - "mesh_spacing": 2 * min_dimension / (2**mesh_spacing - 1), + "mesh_spacing": 2 * min_dimension / (2**ns - 1), } - for interpolation_nodes, mesh_spacing in product( + for interpolation_nodes, ns in product( range(nodes_lo, nodes_hi + 1), range(mesh_lo, mesh_hi + 1) ) ] diff --git a/src/torchpme/tuning/tuner.py b/src/torchpme/tuning/tuner.py index bae34b9f..36d56908 100644 --- a/src/torchpme/tuning/tuner.py +++ b/src/torchpme/tuning/tuner.py @@ -11,6 +11,23 @@ class TunerBase: + """ + Base class defining the interface for a parameter tuner. + + This class provides a framework for tuning the parameters of a calculator. The class + itself supports estimating the ``smearing`` from the real space cutoff based on the + real space error formula. The :func:`TunerBase.tune` defines the interface for a + sophisticated tuning process, which takes a value of the desired accuracy. + + :param charges: atomic charges + :param cell: single tensor of shape (3, 3), describing the bounding + :param positions: single tensor of shape (``len(charges), 3``) containing the + Cartesian positions of all point charges in the system. + :param cutoff: real space cutoff, serves as a hyperparameter here. + :param calculator: the calculator to be tuned + :param exponent: exponent of the potential, only exponent = 1 is supported + """ + def __init__( self, charges: torch.Tensor, @@ -18,35 +35,19 @@ def __init__( positions: torch.Tensor, cutoff: float, calculator: type[Calculator], - params: list[dict], exponent: int = 1, - neighbor_indices: Optional[torch.Tensor] = None, - neighbor_distances: Optional[torch.Tensor] = None, ): self._validate_parameters(charges, cell, positions, exponent) self.charges = charges self.cell = cell self.positions = positions self.cutoff = cutoff - self.exponent = calculator.exponent + self.calculator = calculator + self.exponent = exponent self._dtype = cell.dtype self._device = cell.device - self.calculator = calculator - self.params = params - self._prefac = 2 * float((charges**2).sum()) / math.sqrt(len(positions)) - self.time_func = TuningTimings( - charges, - cell, - positions, - cutoff, - neighbor_indices, - neighbor_distances, - 4, - 2, - True, - ) def tune(self, accuracy: float = 1e-3): raise NotImplementedError @@ -131,7 +132,13 @@ def estimate_smearing( self, accuracy: float, ) -> float: - """Estimate the smearing based on the error formula of the real space.""" + """ + Estimate the smearing based on the error formula of the real space. The + smearing is set as leading to a real space error of ``accuracy/4``. + + :param accuracy: a float, the desired accuracy + :return: a float, the estimated smearing + """ ratio = math.sqrt( -2 * math.log( @@ -147,7 +154,68 @@ def estimate_smearing( class GridSearchTuner(TunerBase): - def tune(self, accuracy: float = 1e-3): + """ + Tuner using grid search. + + The tuner uses the error formula to estimate the error of a given parameter set. + If the error is smaller than the accuracy, the timing is measured and returned. + If the error is larger than the accuracy, the timing is set to infinity and the + parameter is skipped. + + :param charges: atomic charges + :param cell: single tensor of shape (3, 3), describing the bounding + :param positions: single tensor of shape (``len(charges), 3``) containing the + Cartesian positions of all point charges in the system. + :param cutoff: real space cutoff, serves as a hyperparameter here. + :param calculator: the calculator to be tuned + :param params: list of Fourier space parameter sets for which the error is estimated + :param exponent: exponent of the potential, only exponent = 1 is supported + :param neighbor_indices: torch.Tensor with the ``i,j`` indices of neighbors for + which the potential should be computed in real space. + :param neighbor_distances: torch.Tensor with the pair distances of the neighbors + for which the potential should be computed in real space. + """ + + def __init__( + self, + charges: torch.Tensor, + cell: torch.Tensor, + positions: torch.Tensor, + cutoff: float, + calculator: type[Calculator], + params: list[dict], + exponent: int = 1, + neighbor_indices: Optional[torch.Tensor] = None, + neighbor_distances: Optional[torch.Tensor] = None, + ): + super().__init__( + charges, + cell, + positions, + cutoff, + calculator, + exponent, + ) + self.params = params + self.time_func = TuningTimings( + charges, + cell, + positions, + cutoff, + neighbor_indices, + neighbor_distances, + True, + ) + + def tune(self, accuracy: float = 1e-3) -> tuple[list[float], list[float]]: + """ + Estimate the error and timing for each parameter set. Only parameters for + which the error is smaller than the accuracy are timed, the others' timing is + set to infinity. + + :param accuracy: a float, the desired accuracy + :return: a list of errors and a list of timings + """ if self.calculator is EwaldCalculator: error_bounds = EwaldErrorBounds(self.charges, self.cell, self.positions) elif self.calculator is PMECalculator: @@ -184,7 +252,26 @@ def _timing(self, smearing: float, k_space_params: dict): class TuningTimings(torch.nn.Module): - """Base class for error bounds.""" + """ + Class for timing a calculator. + + The class estimates the average execution time of a given calculater after several + warmup runs. The class takes the information of the structure that one wants to + benchmark on, and the configuration of the timing process as inputs. + + :param charges: atomic charges + :param cell: single tensor of shape (3, 3), describing the bounding + :param positions: single tensor of shape (``len(charges), 3``) containing the + Cartesian positions of all point charges in the system. + :param cutoff: real space cutoff, serves as a hyperparameter here. + :param neighbor_indices: torch.Tensor with the ``i,j`` indices of neighbors for + which the potential should be computed in real space. + :param neighbor_distances: torch.Tensor with the pair distances of the neighbors for + which the potential should be computed in real space. + :param n_repeat: number of times to repeat to estimate the average timing + :param n_warmup: number of warmup runs + :param run_backward: whether to run the backward pass + """ def __init__( self, @@ -231,6 +318,9 @@ def forward(self, calculator: torch.nn.Module): """ Estimate the execution time of a given calculator for the structure to be used as benchmark. + + :param calculator: the calculator to be tuned + :return: a float, the average execution time """ for _ in range(self._n_warmup): result = calculator.forward(