From fe6b04590ee6701a61f8e80768dd6c6d083407e5 Mon Sep 17 00:00:00 2001 From: GardevoirX Date: Wed, 15 Jan 2025 14:36:15 +0100 Subject: [PATCH] Rearrange again --- src/torchpme/tuning/__init__.py | 2 +- src/torchpme/tuning/base.py | 247 ---------------------------- src/torchpme/tuning/error_bounds.py | 22 ++- src/torchpme/tuning/ewald.py | 10 +- src/torchpme/tuning/p3m.py | 65 +++++++- src/torchpme/tuning/pme.py | 8 +- src/torchpme/tuning/tuner.py | 235 +++++++++++++++++++++++++- 7 files changed, 327 insertions(+), 262 deletions(-) delete mode 100644 src/torchpme/tuning/base.py diff --git a/src/torchpme/tuning/__init__.py b/src/torchpme/tuning/__init__.py index 0759562e..f76ca50a 100644 --- a/src/torchpme/tuning/__init__.py +++ b/src/torchpme/tuning/__init__.py @@ -1,6 +1,6 @@ from .ewald import tune_ewald -from .pme import tune_pme from .p3m import tune_p3m +from .pme import tune_pme __all__ = [ "tune_ewald", diff --git a/src/torchpme/tuning/base.py b/src/torchpme/tuning/base.py deleted file mode 100644 index df8d9258..00000000 --- a/src/torchpme/tuning/base.py +++ /dev/null @@ -1,247 +0,0 @@ -import math -from typing import Optional - -import torch -import vesin.torch - - -class TunerBase: - def __init__( - self, - charges: torch.Tensor, - cell: torch.Tensor, - positions: torch.Tensor, - cutoff: float, - calculator, - params: list[dict], - exponent: int = 1, - neighbor_indices: Optional[torch.Tensor] = None, - neighbor_distances: Optional[torch.Tensor] = None, - ): - self._validate_parameters(charges, cell, positions, exponent) - self.charges = charges - self.cell = cell - self.positions = positions - self.cutoff = cutoff - self.exponent = exponent - self._dtype = cell.dtype - self._device = cell.device - - self.calculator = calculator - self.params = params - - self._prefac = 2 * float((charges**2).sum()) / math.sqrt(len(positions)) - self.time_func = TuningTimings( - charges, - cell, - positions, - cutoff, - neighbor_indices, - neighbor_distances, - 4, - 2, - True, - ) - - def tune(self, accuracy: float = 1e-3): - raise NotImplementedError - - @staticmethod - def _validate_parameters( - charges: torch.Tensor, - cell: torch.Tensor, - positions: torch.Tensor, - exponent: int, - ) -> None: - if exponent != 1: - raise NotImplementedError("Only exponent = 1 is supported") - - if list(positions.shape) != [len(positions), 3]: - raise ValueError( - "each `positions` must be a tensor with shape [n_atoms, 3], got at least " - f"one tensor with shape {list(positions.shape)}" - ) - - # check shape, dtype and device of cell - dtype = positions.dtype - if cell.dtype != dtype: - raise ValueError( - f"each `cell` must have the same type {dtype} as `positions`, got at least " - "one tensor of type " - f"{cell.dtype}" - ) - - device = positions.device - if cell.device != device: - raise ValueError( - f"each `cell` must be on the same device {device} as `positions`, got at " - "least one tensor with device " - f"{cell.device}" - ) - - if list(cell.shape) != [3, 3]: - raise ValueError( - "each `cell` must be a tensor with shape [3, 3], got at least one tensor " - f"with shape {list(cell.shape)}" - ) - - if torch.equal( - cell.det(), torch.full([], 0, dtype=cell.dtype, device=cell.device) - ): - raise ValueError( - "provided `cell` has a determinant of 0 and therefore is not valid for " - "periodic calculation" - ) - - if charges.dtype != dtype: - raise ValueError( - f"each `charges` must have the same type {dtype} as `positions`, got at least " - "one tensor of type " - f"{charges.dtype}" - ) - - if charges.device != device: - raise ValueError( - f"each `charges` must be on the same device {device} as `positions`, got at " - "least one tensor with device " - f"{charges.device}" - ) - - if charges.dim() != 2: - raise ValueError( - "`charges` must be a 2-dimensional tensor, got " - f"tensor with {charges.dim()} dimension(s) and shape " - f"{list(charges.shape)}" - ) - - if list(charges.shape) != [len(positions), charges.shape[1]]: - raise ValueError( - "`charges` must be a tensor with shape [n_atoms, n_channels], with " - "`n_atoms` being the same as the variable `positions`. Got tensor with " - f"shape {list(charges.shape)} where positions contains " - f"{len(positions)} atoms" - ) - - def estimate_smearing( - self, - accuracy: float, - ) -> float: - """Estimate the smearing based on the error formula of the real space.""" - ratio = math.sqrt( - -2 - * math.log( - accuracy - / 2 - / self._prefac - * math.sqrt(self.cutoff * float(torch.abs(self.cell.det()))) - ) - ) - smearing = self.cutoff / ratio - - return float(smearing) - -class TuningErrorBounds(torch.nn.Module): - """Base class for error bounds.""" - - def __init__( - self, - charges: torch.Tensor, - cell: torch.Tensor, - positions: torch.Tensor, - ): - super().__init__() - self._charges = charges - self._cell = cell - self._positions = positions - - def forward(self, *args, **kwargs): - return self.error(*args, **kwargs) - - -class TuningTimings(torch.nn.Module): - """Base class for error bounds.""" - - def __init__( - self, - charges: torch.Tensor, - cell: torch.Tensor, - positions: torch.Tensor, - cutoff: float, - neighbor_indices: Optional[torch.Tensor] = None, - neighbor_distances: Optional[torch.Tensor] = None, - n_repeat: int = 4, - n_warmup: int = 2, - run_backward: Optional[bool] = True, - ): - super().__init__() - self._charges = charges - self._cell = cell - self._positions = positions - self._dtype = charges.dtype - self._device = charges.device - self._n_repeat = n_repeat - self._n_warmup = n_warmup - self._run_backward = run_backward - - if neighbor_indices is None and neighbor_distances is None: - nl = vesin.torch.NeighborList(cutoff=cutoff, full_list=False) - i, j, neighbor_distances = nl.compute( - points=self._positions.to(dtype=torch.float64, device="cpu"), - box=self._cell.to(dtype=torch.float64, device="cpu"), - periodic=True, - quantities="ijd", - ) - neighbor_indices = torch.stack([i, j], dim=1) - elif neighbor_indices is None or neighbor_distances is None: - raise ValueError( - "If neighbor_indices or neighbor_distances are None, " - "both must be None." - ) - self._neighbor_indices = neighbor_indices.to(device=self._device) - self._neighbor_distances = neighbor_distances.to( - dtype=self._dtype, device=self._device - ) - - def forward(self, calculator: torch.nn.Module): - """ - Estimate the execution time of a given calculator for the structure - to be used as benchmark. - """ - for _ in range(self._n_warmup): - result = calculator.forward( - positions=self._positions, - charges=self._charges, - cell=self._cell, - neighbor_indices=self._neighbor_indices, - neighbor_distances=self._neighbor_distances, - ) - - # measure time - execution_time = 0.0 - - for _ in range(self._n_repeat): - positions = self._positions.clone() - cell = self._cell.clone() - charges = self._charges.clone() - # nb - this won't compute gradiens involving the distances - if self._run_backward: - positions.requires_grad_(True) - cell.requires_grad_(True) - charges.requires_grad_(True) - execution_time -= time.time() - result = calculator.forward( - positions=positions, - charges=charges, - cell=cell, - neighbor_indices=self._neighbor_indices, - neighbor_distances=self._neighbor_distances, - ) - value = result.sum() - if self._run_backward: - value.backward(retain_graph=True) - - if self._device is torch.device("cuda"): - torch.cuda.synchronize() - execution_time += time.time() - - return execution_time / self._n_repeat diff --git a/src/torchpme/tuning/error_bounds.py b/src/torchpme/tuning/error_bounds.py index e4e59865..3baf5465 100644 --- a/src/torchpme/tuning/error_bounds.py +++ b/src/torchpme/tuning/error_bounds.py @@ -1,10 +1,28 @@ import math import torch -from .base import TuningErrorBounds TWO_PI = 2 * math.pi + +class TuningErrorBounds(torch.nn.Module): + """Base class for error bounds.""" + + def __init__( + self, + charges: torch.Tensor, + cell: torch.Tensor, + positions: torch.Tensor, + ): + super().__init__() + self._charges = charges + self._cell = cell + self._positions = positions + + def forward(self, *args, **kwargs): + return self.error(*args, **kwargs) + + class EwaldErrorBounds(TuningErrorBounds): r""" Error bounds for :class:`torchpme.calculators.ewald.EwaldCalculator`. @@ -72,7 +90,7 @@ def forward(self, smearing, lr_wavelength, cutoff): self.err_kspace(smearing, lr_wavelength) ** 2 + self.err_rspace(smearing, cutoff) ** 2 ) - + # Coefficients for the P3M Fourier error, # see Table II of http://dx.doi.org/10.1063/1.477415 diff --git a/src/torchpme/tuning/ewald.py b/src/torchpme/tuning/ewald.py index d2739551..76188e9c 100644 --- a/src/torchpme/tuning/ewald.py +++ b/src/torchpme/tuning/ewald.py @@ -1,9 +1,13 @@ +import math from typing import Optional + import torch from ..calculators import EwaldCalculator from .tuner import GridSearchTuner +TWO_PI = 2 * math.pi + def tune_ewald( charges: torch.Tensor, @@ -76,7 +80,6 @@ def tune_ewald( 4.4 """ - params = [{"lr_wavelength": ns} for ns in range(ns_lo, ns_hi + 1)] tuner = GridSearchTuner( charges, @@ -97,6 +100,5 @@ def tune_ewald( # calculation time. The timing of those parameters leading to an higher error # than the accuracy are set to infinity return smearing, params[timings.index(min(timings))] - else: - # No parameter meets the requirement, return the one with the smallest error - return smearing, params[errs.index(min(errs))] + # No parameter meets the requirement, return the one with the smallest error + return smearing, params[errs.index(min(errs))] diff --git a/src/torchpme/tuning/p3m.py b/src/torchpme/tuning/p3m.py index 546a1849..b3e44021 100644 --- a/src/torchpme/tuning/p3m.py +++ b/src/torchpme/tuning/p3m.py @@ -1,3 +1,4 @@ +import math from itertools import product from typing import Optional @@ -6,6 +7,65 @@ from ..calculators import P3MCalculator from .tuner import GridSearchTuner +TWO_PI = 2 * math.pi + +# Coefficients for the P3M Fourier error, +# see Table II of http://dx.doi.org/10.1063/1.477415 +A_COEF = [ + [None, 2 / 3, 1 / 50, 1 / 588, 1 / 4320, 1 / 23_232, 691 / 68_140_800, 1 / 345_600], + [ + None, + None, + 5 / 294, + 7 / 1440, + 3 / 1936, + 7601 / 13_628_160, + 13 / 57_600, + 3617 / 35_512_320, + ], + [ + None, + None, + None, + 21 / 3872, + 7601 / 2_271_360, + 143 / 69_120, + 47_021 / 35_512_320, + 745_739 / 838_397_952, + ], + [ + None, + None, + None, + None, + 143 / 28_800, + 517_231 / 106_536_960, + 9_694_607 / 2_095_994_880, + 56_399_353 / 12_773_376_000, + ], + [ + None, + None, + None, + None, + None, + 106_640_677 / 11_737_571_328, + 733_191_589 / 59_609_088_000, + 25_091_609 / 1_560_084_480, + ], + [ + None, + None, + None, + None, + None, + None, + 326_190_917 / 11_700_633_600, + 1_755_948_832_039 / 36_229_939_200_000, + ], + [None, None, None, None, None, None, None, 4_887_769_399 / 37_838_389_248], +] + def tune_p3m( charges: torch.Tensor, @@ -108,6 +168,5 @@ def tune_p3m( # calculation time. The timing of those parameters leading to an higher error # than the accuracy are set to infinity return smearing, params[timings.index(min(timings))] - else: - # No parameter meets the requirement, return the one with the smallest error - return smearing, params[errs.index(min(errs))] + # No parameter meets the requirement, return the one with the smallest error + return smearing, params[errs.index(min(errs))] diff --git a/src/torchpme/tuning/pme.py b/src/torchpme/tuning/pme.py index 2f764a2d..ec983f97 100644 --- a/src/torchpme/tuning/pme.py +++ b/src/torchpme/tuning/pme.py @@ -1,3 +1,4 @@ +import math from itertools import product from typing import Optional @@ -6,6 +7,8 @@ from ..calculators import PMECalculator from .tuner import GridSearchTuner +TWO_PI = 2 * math.pi + def tune_pme( charges: torch.Tensor, @@ -108,6 +111,5 @@ def tune_pme( # calculation time. The timing of those parameters leading to an higher error # than the accuracy are set to infinity return smearing, params[timings.index(min(timings))] - else: - # No parameter meets the requirement, return the one with the smallest error - return smearing, params[errs.index(min(errs))] + # No parameter meets the requirement, return the one with the smallest error + return smearing, params[errs.index(min(errs))] diff --git a/src/torchpme/tuning/tuner.py b/src/torchpme/tuning/tuner.py index 0feffdc4..bae34b9f 100644 --- a/src/torchpme/tuning/tuner.py +++ b/src/torchpme/tuning/tuner.py @@ -1,9 +1,151 @@ -from ..calculators import EwaldCalculator, P3MCalculator, PMECalculator +import math +import time +from typing import Optional + +import torch +import vesin.torch + +from ..calculators import Calculator, EwaldCalculator, P3MCalculator, PMECalculator from ..potentials import CoulombPotential -from .base import TunerBase from .error_bounds import EwaldErrorBounds, P3MErrorBounds, PMEErrorBounds +class TunerBase: + def __init__( + self, + charges: torch.Tensor, + cell: torch.Tensor, + positions: torch.Tensor, + cutoff: float, + calculator: type[Calculator], + params: list[dict], + exponent: int = 1, + neighbor_indices: Optional[torch.Tensor] = None, + neighbor_distances: Optional[torch.Tensor] = None, + ): + self._validate_parameters(charges, cell, positions, exponent) + self.charges = charges + self.cell = cell + self.positions = positions + self.cutoff = cutoff + self.exponent = calculator.exponent + self._dtype = cell.dtype + self._device = cell.device + + self.calculator = calculator + self.params = params + + self._prefac = 2 * float((charges**2).sum()) / math.sqrt(len(positions)) + self.time_func = TuningTimings( + charges, + cell, + positions, + cutoff, + neighbor_indices, + neighbor_distances, + 4, + 2, + True, + ) + + def tune(self, accuracy: float = 1e-3): + raise NotImplementedError + + @staticmethod + def _validate_parameters( + charges: torch.Tensor, + cell: torch.Tensor, + positions: torch.Tensor, + exponent: int, + ) -> None: + if exponent != 1: + raise NotImplementedError("Only exponent = 1 is supported") + + if list(positions.shape) != [len(positions), 3]: + raise ValueError( + "each `positions` must be a tensor with shape [n_atoms, 3], got at least " + f"one tensor with shape {list(positions.shape)}" + ) + + # check shape, dtype and device of cell + dtype = positions.dtype + if cell.dtype != dtype: + raise ValueError( + f"each `cell` must have the same type {dtype} as `positions`, got at least " + "one tensor of type " + f"{cell.dtype}" + ) + + device = positions.device + if cell.device != device: + raise ValueError( + f"each `cell` must be on the same device {device} as `positions`, got at " + "least one tensor with device " + f"{cell.device}" + ) + + if list(cell.shape) != [3, 3]: + raise ValueError( + "each `cell` must be a tensor with shape [3, 3], got at least one tensor " + f"with shape {list(cell.shape)}" + ) + + if torch.equal( + cell.det(), torch.full([], 0, dtype=cell.dtype, device=cell.device) + ): + raise ValueError( + "provided `cell` has a determinant of 0 and therefore is not valid for " + "periodic calculation" + ) + + if charges.dtype != dtype: + raise ValueError( + f"each `charges` must have the same type {dtype} as `positions`, got at least " + "one tensor of type " + f"{charges.dtype}" + ) + + if charges.device != device: + raise ValueError( + f"each `charges` must be on the same device {device} as `positions`, got at " + "least one tensor with device " + f"{charges.device}" + ) + + if charges.dim() != 2: + raise ValueError( + "`charges` must be a 2-dimensional tensor, got " + f"tensor with {charges.dim()} dimension(s) and shape " + f"{list(charges.shape)}" + ) + + if list(charges.shape) != [len(positions), charges.shape[1]]: + raise ValueError( + "`charges` must be a tensor with shape [n_atoms, n_channels], with " + "`n_atoms` being the same as the variable `positions`. Got tensor with " + f"shape {list(charges.shape)} where positions contains " + f"{len(positions)} atoms" + ) + + def estimate_smearing( + self, + accuracy: float, + ) -> float: + """Estimate the smearing based on the error formula of the real space.""" + ratio = math.sqrt( + -2 + * math.log( + accuracy + / 2 + / self._prefac + * math.sqrt(self.cutoff * float(torch.abs(self.cell.det()))) + ) + ) + smearing = self.cutoff / ratio + + return float(smearing) + + class GridSearchTuner(TunerBase): def tune(self, accuracy: float = 1e-3): if self.calculator is EwaldCalculator: @@ -39,3 +181,92 @@ def _timing(self, smearing: float, k_space_params: dict): ) return self.time_func(calculator) + + +class TuningTimings(torch.nn.Module): + """Base class for error bounds.""" + + def __init__( + self, + charges: torch.Tensor, + cell: torch.Tensor, + positions: torch.Tensor, + cutoff: float, + neighbor_indices: Optional[torch.Tensor] = None, + neighbor_distances: Optional[torch.Tensor] = None, + n_repeat: int = 4, + n_warmup: int = 2, + run_backward: Optional[bool] = True, + ): + super().__init__() + self._charges = charges + self._cell = cell + self._positions = positions + self._dtype = charges.dtype + self._device = charges.device + self._n_repeat = n_repeat + self._n_warmup = n_warmup + self._run_backward = run_backward + + if neighbor_indices is None and neighbor_distances is None: + nl = vesin.torch.NeighborList(cutoff=cutoff, full_list=False) + i, j, neighbor_distances = nl.compute( + points=self._positions.to(dtype=torch.float64, device="cpu"), + box=self._cell.to(dtype=torch.float64, device="cpu"), + periodic=True, + quantities="ijd", + ) + neighbor_indices = torch.stack([i, j], dim=1) + elif neighbor_indices is None or neighbor_distances is None: + raise ValueError( + "If neighbor_indices or neighbor_distances are None, " + "both must be None." + ) + self._neighbor_indices = neighbor_indices.to(device=self._device) + self._neighbor_distances = neighbor_distances.to( + dtype=self._dtype, device=self._device + ) + + def forward(self, calculator: torch.nn.Module): + """ + Estimate the execution time of a given calculator for the structure + to be used as benchmark. + """ + for _ in range(self._n_warmup): + result = calculator.forward( + positions=self._positions, + charges=self._charges, + cell=self._cell, + neighbor_indices=self._neighbor_indices, + neighbor_distances=self._neighbor_distances, + ) + + # measure time + execution_time = 0.0 + + for _ in range(self._n_repeat): + positions = self._positions.clone() + cell = self._cell.clone() + charges = self._charges.clone() + # nb - this won't compute gradiens involving the distances + if self._run_backward: + positions.requires_grad_(True) + cell.requires_grad_(True) + charges.requires_grad_(True) + execution_time -= time.time() + result = calculator.forward( + positions=positions, + charges=charges, + cell=cell, + neighbor_indices=self._neighbor_indices, + neighbor_distances=self._neighbor_distances, + ) + value = result.sum() + if self._run_backward: + value.backward(retain_graph=True) + + if self._device is torch.device("cuda"): + torch.cuda.synchronize() + execution_time += time.time() + + return execution_time / self._n_repeat