From a41f780727a85709a70404421573408c46be39f9 Mon Sep 17 00:00:00 2001
From: GardevoirX
Date: Tue, 7 Jan 2025 22:39:06 +0100
Subject: [PATCH] Fix doctests and remove orphan functions

---
 src/torchpme/utils/tuning/__init__.py    | 100 --------------------
 src/torchpme/utils/tuning/ewald.py       |   5 +-
 src/torchpme/utils/tuning/grid_search.py | 111 +++++++++++++++++++++--
 src/torchpme/utils/tuning/p3m.py         |   8 +-
 src/torchpme/utils/tuning/pme.py         |   8 +-
 5 files changed, 118 insertions(+), 114 deletions(-)

diff --git a/src/torchpme/utils/tuning/__init__.py b/src/torchpme/utils/tuning/__init__.py
index 0f5a6640..50cb8fef 100644
--- a/src/torchpme/utils/tuning/__init__.py
+++ b/src/torchpme/utils/tuning/__init__.py
@@ -1,4 +1,3 @@
-import math
 import time
 from typing import Optional
 
@@ -6,105 +5,6 @@
 import vesin.torch
 
 
-def _estimate_smearing_cutoff(
-    cell: torch.Tensor,
-    smearing: Optional[float],
-    cutoff: Optional[float],
-    accuracy: float,
-    prefac: float,
-) -> tuple[float, float]:
-    cell_dimensions = torch.linalg.norm(cell, dim=1)
-    min_dimension = float(torch.min(cell_dimensions))
-    half_cell = min_dimension / 2.0
-    cutoff_init = min(5.0, half_cell) if cutoff is None else cutoff
-    ratio = math.sqrt(
-        -2
-        * math.log(
-            accuracy
-            / 2
-            / prefac
-            * math.sqrt(cutoff_init * float(torch.abs(cell.det())))
-        )
-    )
-    smearing_init = cutoff_init / ratio if smearing is None else smearing
-
-    return float(smearing_init), float(cutoff_init)
-
-
-def _validate_parameters(
-    charges: torch.Tensor,
-    cell: torch.Tensor,
-    positions: torch.Tensor,
-    exponent: int,
-) -> None:
-    if exponent != 1:
-        raise NotImplementedError("Only exponent = 1 is supported")
-
-    if list(positions.shape) != [len(positions), 3]:
-        raise ValueError(
-            "each `positions` must be a tensor with shape [n_atoms, 3], got at least "
-            f"one tensor with shape {list(positions.shape)}"
-        )
-
-    # check shape, dtype and device of cell
-    dtype = positions.dtype
-    if cell.dtype != dtype:
-        raise ValueError(
-            f"each `cell` must have the same type {dtype} as `positions`, got at least "
-            "one tensor of type "
-            f"{cell.dtype}"
-        )
-
-    device = positions.device
-    if cell.device != device:
-        raise ValueError(
-            f"each `cell` must be on the same device {device} as `positions`, got at "
-            "least one tensor with device "
-            f"{cell.device}"
-        )
-
-    if list(cell.shape) != [3, 3]:
-        raise ValueError(
-            "each `cell` must be a tensor with shape [3, 3], got at least one tensor "
-            f"with shape {list(cell.shape)}"
-        )
-
-    if torch.equal(cell.det(), torch.full([], 0, dtype=cell.dtype, device=cell.device)):
-        raise ValueError(
-            "provided `cell` has a determinant of 0 and therefore is not valid for "
-            "periodic calculation"
-        )
-
-    if charges.dtype != dtype:
-        raise ValueError(
-            f"each `charges` must have the same type {dtype} as `positions`, got at least "
-            "one tensor of type "
-            f"{charges.dtype}"
-        )
-
-    if charges.device != device:
-        raise ValueError(
-            f"each `charges` must be on the same device {device} as `positions`, got at "
-            "least one tensor with device "
-            f"{charges.device}"
-        )
-
-    if charges.dim() != 2:
-        raise ValueError(
-            "`charges` must be a 2-dimensional tensor, got "
-            f"tensor with {charges.dim()} dimension(s) and shape "
-            f"{list(charges.shape)}"
-        )
-
-    if list(charges.shape) != [len(positions), charges.shape[1]]:
-        raise ValueError(
-            "`charges` must be a tensor with shape [n_atoms, n_channels], with "
-            "`n_atoms` being the same as the variable `positions`. Got tensor with "
-            f"shape {list(charges.shape)} where positions contains "
-            f"{len(positions)} atoms"
-        )
-
-
 class TuningErrorBounds(torch.nn.Module):
     """Base class for error bounds."""
 
diff --git a/src/torchpme/utils/tuning/ewald.py b/src/torchpme/utils/tuning/ewald.py
index 23952480..4009e097 100644
--- a/src/torchpme/utils/tuning/ewald.py
+++ b/src/torchpme/utils/tuning/ewald.py
@@ -91,7 +91,7 @@ class EwaldTuner(GridSearchBase):
 
     ErrorBounds = EwaldErrorBounds
     CalculatorClass = EwaldCalculator
-    GridSearchParams = {"lr_wavelength": 1 / np.arange(1, 15)}
+    TemplateGridSearchParams = {"lr_wavelength": 1 / np.arange(1, 15)}
 
     def __init__(
         self,
@@ -115,6 +115,9 @@ def __init__(
         self.GridSearchParams["lr_wavelength"] *= float(
             torch.min(self._cell_dimensions)
         )
+        self.GridSearchParams["lr_wavelength"] = list(
+            self.GridSearchParams["lr_wavelength"]
+        )
 
 
 def tune_ewald(
diff --git a/src/torchpme/utils/tuning/grid_search.py b/src/torchpme/utils/tuning/grid_search.py
index e0105860..b74166bc 100644
--- a/src/torchpme/utils/tuning/grid_search.py
+++ b/src/torchpme/utils/tuning/grid_search.py
@@ -1,6 +1,7 @@
 #!/usr/bin/env python3
 
 import math
+from copy import copy
 from itertools import product
 from typing import Optional
 from warnings import warn
@@ -14,8 +15,6 @@
 from . import (
     TuningErrorBounds,
     TuningTimings,
-    _estimate_smearing_cutoff,
-    _validate_parameters,
 )
 
 
@@ -39,7 +38,9 @@ class GridSearchBase:
     ErrorBounds: type[TuningErrorBounds]
     Timings: type[TuningTimings] = TuningTimings
     CalculatorClass: type[Calculator]
-    GridSearchParams: dict[str, torch.Tensor]  # {"interpolation_nodes": ..., ...}
+    TemplateGridSearchParams: dict[
+        str, torch.Tensor
+    ]  # {"interpolation_nodes": ..., ...}
 
     def __init__(
         self,
@@ -51,7 +52,7 @@ def __init__(
         neighbor_indices: Optional[torch.Tensor] = None,
         neighbor_distances: Optional[torch.Tensor] = None,
     ):
-        _validate_parameters(charges, cell, positions, exponent)
+        self._validate_parameters(charges, cell, positions, exponent)
         self.charges = charges
         self.cell = cell
         self.positions = positions
@@ -74,6 +75,101 @@ def __init__(
         )
 
         self._prefac = 2 * (charges**2).sum() / math.sqrt(len(positions))
+        self.GridSearchParams = copy(self.TemplateGridSearchParams)
+
+    @staticmethod
+    def _validate_parameters(
+        charges: torch.Tensor,
+        cell: torch.Tensor,
+        positions: torch.Tensor,
+        exponent: int,
+    ) -> None:
+        if exponent != 1:
+            raise NotImplementedError("Only exponent = 1 is supported")
+
+        if list(positions.shape) != [len(positions), 3]:
+            raise ValueError(
+                "each `positions` must be a tensor with shape [n_atoms, 3], got at least "
+                f"one tensor with shape {list(positions.shape)}"
+            )
+
+        # check shape, dtype and device of cell
+        dtype = positions.dtype
+        if cell.dtype != dtype:
+            raise ValueError(
+                f"each `cell` must have the same type {dtype} as `positions`, got at least "
+                "one tensor of type "
+                f"{cell.dtype}"
+            )
+
+        device = positions.device
+        if cell.device != device:
+            raise ValueError(
+                f"each `cell` must be on the same device {device} as `positions`, got at "
+                "least one tensor with device "
+                f"{cell.device}"
+            )
+
+        if list(cell.shape) != [3, 3]:
+            raise ValueError(
+                "each `cell` must be a tensor with shape [3, 3], got at least one tensor "
+                f"with shape {list(cell.shape)}"
+            )
+
+        if torch.equal(
+            cell.det(), torch.full([], 0, dtype=cell.dtype, device=cell.device)
+        ):
+            raise ValueError(
+                "provided `cell` has a determinant of 0 and therefore is not valid for "
+                "periodic calculation"
+            )
+
+        if charges.dtype != dtype:
+            raise ValueError(
+                f"each `charges` must have the same type {dtype} as `positions`, got at least "
+                "one tensor of type "
+                f"{charges.dtype}"
+            )
+
+        if charges.device != device:
+            raise ValueError(
+                f"each `charges` must be on the same device {device} as `positions`, got at "
+                "least one tensor with device "
+                f"{charges.device}"
+            )
+
+        if charges.dim() != 2:
+            raise ValueError(
+                "`charges` must be a 2-dimensional tensor, got "
+                f"tensor with {charges.dim()} dimension(s) and shape "
+                f"{list(charges.shape)}"
+            )
+
+        if list(charges.shape) != [len(positions), charges.shape[1]]:
+            raise ValueError(
+                "`charges` must be a tensor with shape [n_atoms, n_channels], with "
+                "`n_atoms` being the same as the variable `positions`. Got tensor with "
+                f"shape {list(charges.shape)} where positions contains "
+                f"{len(positions)} atoms"
+            )
+
+    def _estimate_smearing(
+        self,
+        accuracy: float,
+    ) -> float:
+        """Estimate the smearing based on the error formula of the real space."""
+        ratio = math.sqrt(
+            -2
+            * math.log(
+                accuracy
+                / 2
+                / self._prefac
+                * math.sqrt(self.cutoff * float(torch.abs(self.cell.det())))
+            )
+        )
+        smearing_init = self.cutoff / ratio
+
+        return float(smearing_init)
 
     def tune(
         self,
@@ -113,13 +209,10 @@ def tune(
         cutoff_err_opt = None
         err_opt = torch.inf
 
-        smearing, cutoff = _estimate_smearing_cutoff(
-            self.cell,
-            smearing=None,
-            cutoff=self.cutoff,
+        smearing = self._estimate_smearing(
            accuracy=accuracy,
-            prefac=self._prefac,
        )
+        cutoff = self.cutoff
         for param_values in product(*self.GridSearchParams.values()):
             params = dict(zip(self.GridSearchParams.keys(), param_values))
             err = self.err_func(
diff --git a/src/torchpme/utils/tuning/p3m.py b/src/torchpme/utils/tuning/p3m.py
index 7f765c06..de11a762 100644
--- a/src/torchpme/utils/tuning/p3m.py
+++ b/src/torchpme/utils/tuning/p3m.py
@@ -160,9 +160,10 @@ class P3MTuner(GridSearchBase):
 
     ErrorBounds = P3MErrorBounds
     CalculatorClass = P3MCalculator
-    GridSearchParams = {
+    TemplateGridSearchParams = {
         "interpolation_nodes": [2, 3, 4, 5],
-        "mesh_spacing": 1 / ((np.exp2(np.arange(2, 8)) - 1) / 2),
+        "mesh_spacing": 1
+        / ((np.exp2(np.arange(2, 8)) - 1) / 2),  # will be converted into a list later
     }
 
     def __init__(
@@ -185,6 +186,9 @@ def __init__(
             neighbor_distances,
         )
         self.GridSearchParams["mesh_spacing"] *= float(torch.min(self._cell_dimensions))
+        self.GridSearchParams["mesh_spacing"] = list(
+            self.GridSearchParams["mesh_spacing"]
+        )
 
 
 def tune_p3m(
diff --git a/src/torchpme/utils/tuning/pme.py b/src/torchpme/utils/tuning/pme.py
index 0152d0c4..e48356e5 100644
--- a/src/torchpme/utils/tuning/pme.py
+++ b/src/torchpme/utils/tuning/pme.py
@@ -104,9 +104,10 @@ class PMETuner(GridSearchBase):
 
     ErrorBounds = PMEErrorBounds
     CalculatorClass = PMECalculator
-    GridSearchParams = {
+    TemplateGridSearchParams = {
         "interpolation_nodes": [3, 4, 5, 6, 7],
-        "mesh_spacing": 1 / ((np.exp2(np.arange(2, 8)) - 1) / 2),
+        "mesh_spacing": 1
+        / ((np.exp2(np.arange(2, 8)) - 1) / 2),  # will be converted into a list later
     }
 
     def __init__(
@@ -129,6 +130,9 @@ def __init__(
             neighbor_distances,
         )
         self.GridSearchParams["mesh_spacing"] *= float(torch.min(self._cell_dimensions))
+        self.GridSearchParams["mesh_spacing"] = list(
+            self.GridSearchParams["mesh_spacing"]
+        )
 
 
 def tune_pme(
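
Note on the pattern (reviewer remark, not part of the patch): moving from a class-level `GridSearchParams` dict that `__init__` rescales to a per-instance `copy()` of a read-only `TemplateGridSearchParams` is the usual fix for state leaking between instances through a shared mutable class attribute. The sketch below is illustrative only (plain Python with hypothetical `BadTuner`/`GoodTuner` classes, not the torchpme API); it shows the leakage that the per-instance copy prevents, and since `copy()` is shallow it also rebuilds the scaled value instead of modifying it in place.

from copy import copy


class BadTuner:
    # Class-level dict shared by every instance.
    GridSearchParams = {"mesh_spacing": [1.0, 0.5, 0.25]}

    def __init__(self, box_length: float) -> None:
        # Rescaling here mutates the shared dict, so the change leaks
        # into every BadTuner instance created afterwards.
        self.GridSearchParams["mesh_spacing"] = [
            s * box_length for s in self.GridSearchParams["mesh_spacing"]
        ]


class GoodTuner:
    # Read-only template; each instance works on its own shallow copy.
    TemplateGridSearchParams = {"mesh_spacing": [1.0, 0.5, 0.25]}

    def __init__(self, box_length: float) -> None:
        self.GridSearchParams = copy(self.TemplateGridSearchParams)
        # Rebuild the value rather than scaling it in place: the shallow
        # copy still shares the underlying list with the template.
        self.GridSearchParams["mesh_spacing"] = [
            s * box_length for s in self.GridSearchParams["mesh_spacing"]
        ]


BadTuner(10.0)
assert BadTuner.GridSearchParams["mesh_spacing"][0] == 10.0  # template corrupted
GoodTuner(10.0)
assert GoodTuner.TemplateGridSearchParams["mesh_spacing"][0] == 1.0  # template intact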