Fix doctests and remove orphan functions
GardevoirX committed Jan 7, 2025
1 parent 56fc382 commit a41f780
Showing 5 changed files with 118 additions and 114 deletions.
100 changes: 0 additions & 100 deletions src/torchpme/utils/tuning/__init__.py
@@ -1,110 +1,10 @@
import math
import time
from typing import Optional

import torch
import vesin.torch


def _estimate_smearing_cutoff(
cell: torch.Tensor,
smearing: Optional[float],
cutoff: Optional[float],
accuracy: float,
prefac: float,
) -> tuple[float, float]:
cell_dimensions = torch.linalg.norm(cell, dim=1)
min_dimension = float(torch.min(cell_dimensions))
half_cell = min_dimension / 2.0
cutoff_init = min(5.0, half_cell) if cutoff is None else cutoff
ratio = math.sqrt(
-2
* math.log(
accuracy
/ 2
/ prefac
* math.sqrt(cutoff_init * float(torch.abs(cell.det())))
)
)
smearing_init = cutoff_init / ratio if smearing is None else smearing

return float(smearing_init), float(cutoff_init)


def _validate_parameters(
charges: torch.Tensor,
cell: torch.Tensor,
positions: torch.Tensor,
exponent: int,
) -> None:
if exponent != 1:
raise NotImplementedError("Only exponent = 1 is supported")

if list(positions.shape) != [len(positions), 3]:
raise ValueError(
"each `positions` must be a tensor with shape [n_atoms, 3], got at least "
f"one tensor with shape {list(positions.shape)}"
)

# check shape, dtype and device of cell
dtype = positions.dtype
if cell.dtype != dtype:
raise ValueError(
f"each `cell` must have the same type {dtype} as `positions`, got at least "
"one tensor of type "
f"{cell.dtype}"
)

device = positions.device
if cell.device != device:
raise ValueError(
f"each `cell` must be on the same device {device} as `positions`, got at "
"least one tensor with device "
f"{cell.device}"
)

if list(cell.shape) != [3, 3]:
raise ValueError(
"each `cell` must be a tensor with shape [3, 3], got at least one tensor "
f"with shape {list(cell.shape)}"
)

if torch.equal(cell.det(), torch.full([], 0, dtype=cell.dtype, device=cell.device)):
raise ValueError(
"provided `cell` has a determinant of 0 and therefore is not valid for "
"periodic calculation"
)

if charges.dtype != dtype:
raise ValueError(
f"each `charges` must have the same type {dtype} as `positions`, got at least "
"one tensor of type "
f"{charges.dtype}"
)

if charges.device != device:
raise ValueError(
f"each `charges` must be on the same device {device} as `positions`, got at "
"least one tensor with device "
f"{charges.device}"
)

if charges.dim() != 2:
raise ValueError(
"`charges` must be a 2-dimensional tensor, got "
f"tensor with {charges.dim()} dimension(s) and shape "
f"{list(charges.shape)}"
)

if list(charges.shape) != [len(positions), charges.shape[1]]:
raise ValueError(
"`charges` must be a tensor with shape [n_atoms, n_channels], with "
"`n_atoms` being the same as the variable `positions`. Got tensor with "
f"shape {list(charges.shape)} where positions contains "
f"{len(positions)} atoms"
)


class TuningErrorBounds(torch.nn.Module):
"""Base class for error bounds."""

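The deleted block removes the module-level helpers `_estimate_smearing_cutoff` and `_validate_parameters`, whose logic now lives on `GridSearchBase` in `grid_search.py` (as `_validate_parameters` and the cutoff-fixed `_estimate_smearing`, shown further down). For reference, a minimal standalone sketch of the same smearing estimate; the function name and argument list here are illustrative, not part of the package API:

```python
import math

import torch


def estimate_smearing(
    cell: torch.Tensor, cutoff: float, accuracy: float, prefac: float
) -> float:
    """Sketch: invert the real-space error estimate for the Gaussian smearing
    at a fixed real-space cutoff (valid while the requested accuracy is small
    enough that the argument of the logarithm stays below one)."""
    volume = float(torch.abs(torch.linalg.det(cell)))
    ratio = math.sqrt(
        -2 * math.log(accuracy / 2 / prefac * math.sqrt(cutoff * volume))
    )
    return cutoff / ratio
```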
5 changes: 4 additions & 1 deletion src/torchpme/utils/tuning/ewald.py
@@ -91,7 +91,7 @@ class EwaldTuner(GridSearchBase):

ErrorBounds = EwaldErrorBounds
CalculatorClass = EwaldCalculator
GridSearchParams = {"lr_wavelength": 1 / np.arange(1, 15)}
TemplateGridSearchParams = {"lr_wavelength": 1 / np.arange(1, 15)}

def __init__(
self,
@@ -115,6 +115,9 @@ def __init__(
self.GridSearchParams["lr_wavelength"] *= float(
torch.min(self._cell_dimensions)
)
self.GridSearchParams["lr_wavelength"] = list(
self.GridSearchParams["lr_wavelength"]
)


def tune_ewald(
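The rename from `GridSearchParams` to `TemplateGridSearchParams` addresses a sharing problem: the old class-level dictionary was scaled inside `__init__`, so every tuner instance touched the same shared values. With the template pattern, each instance works on its own copy made in `GridSearchBase.__init__` (see `grid_search.py` below). A simplified, self-contained sketch of the idea; the class and attribute names follow the diff, but the constructor argument is invented for illustration:

```python
from copy import copy

import numpy as np


class TunerSketch:
    # Class-level template: only read, never mutated.
    TemplateGridSearchParams = {"lr_wavelength": 1 / np.arange(1, 15)}

    def __init__(self, min_cell_dimension: float):
        # Per-instance copy of the template. Rebinding the key (rather than
        # scaling in place with `*=`) keeps even a shallow copy safe, because
        # the template's array object itself is never modified.
        self.GridSearchParams = copy(self.TemplateGridSearchParams)
        self.GridSearchParams["lr_wavelength"] = list(
            self.GridSearchParams["lr_wavelength"] * min_cell_dimension
        )
```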
111 changes: 102 additions & 9 deletions src/torchpme/utils/tuning/grid_search.py
@@ -1,6 +1,7 @@
#!/usr/bin/env python3

import math
from copy import copy
from itertools import product
from typing import Optional
from warnings import warn
@@ -14,8 +15,6 @@
from . import (
TuningErrorBounds,
TuningTimings,
_estimate_smearing_cutoff,
_validate_parameters,
)


@@ -39,7 +38,9 @@ class GridSearchBase:
ErrorBounds: type[TuningErrorBounds]
Timings: type[TuningTimings] = TuningTimings
CalculatorClass: type[Calculator]
GridSearchParams: dict[str, torch.Tensor] # {"interpolation_nodes": ..., ...}
TemplateGridSearchParams: dict[
str, torch.Tensor
] # {"interpolation_nodes": ..., ...}

def __init__(
self,
@@ -51,7 +52,7 @@ def __init__(
neighbor_indices: Optional[torch.Tensor] = None,
neighbor_distances: Optional[torch.Tensor] = None,
):
_validate_parameters(charges, cell, positions, exponent)
self._validate_parameters(charges, cell, positions, exponent)
self.charges = charges
self.cell = cell
self.positions = positions
@@ -74,6 +75,101 @@ def __init__(
)

self._prefac = 2 * (charges**2).sum() / math.sqrt(len(positions))
self.GridSearchParams = copy(self.TemplateGridSearchParams)

@staticmethod
def _validate_parameters(
charges: torch.Tensor,
cell: torch.Tensor,
positions: torch.Tensor,
exponent: int,
) -> None:
if exponent != 1:
raise NotImplementedError("Only exponent = 1 is supported")

if list(positions.shape) != [len(positions), 3]:
raise ValueError(
"each `positions` must be a tensor with shape [n_atoms, 3], got at least "
f"one tensor with shape {list(positions.shape)}"
)

# check shape, dtype and device of cell
dtype = positions.dtype
if cell.dtype != dtype:
raise ValueError(
f"each `cell` must have the same type {dtype} as `positions`, got at least "
"one tensor of type "
f"{cell.dtype}"
)

device = positions.device
if cell.device != device:
raise ValueError(
f"each `cell` must be on the same device {device} as `positions`, got at "
"least one tensor with device "
f"{cell.device}"
)

if list(cell.shape) != [3, 3]:
raise ValueError(
"each `cell` must be a tensor with shape [3, 3], got at least one tensor "
f"with shape {list(cell.shape)}"
)

if torch.equal(
cell.det(), torch.full([], 0, dtype=cell.dtype, device=cell.device)
):
raise ValueError(
"provided `cell` has a determinant of 0 and therefore is not valid for "
"periodic calculation"
)

if charges.dtype != dtype:
raise ValueError(
f"each `charges` must have the same type {dtype} as `positions`, got at least "
"one tensor of type "
f"{charges.dtype}"
)

if charges.device != device:
raise ValueError(
f"each `charges` must be on the same device {device} as `positions`, got at "
"least one tensor with device "
f"{charges.device}"
)

if charges.dim() != 2:
raise ValueError(
"`charges` must be a 2-dimensional tensor, got "
f"tensor with {charges.dim()} dimension(s) and shape "
f"{list(charges.shape)}"
)

if list(charges.shape) != [len(positions), charges.shape[1]]:
raise ValueError(
"`charges` must be a tensor with shape [n_atoms, n_channels], with "
"`n_atoms` being the same as the variable `positions`. Got tensor with "
f"shape {list(charges.shape)} where positions contains "
f"{len(positions)} atoms"
)

def _estimate_smearing(
self,
accuracy: float,
) -> float:
"""Estimate the smearing based on the error formula of the real space."""
ratio = math.sqrt(
-2
* math.log(
accuracy
/ 2
/ self._prefac
* math.sqrt(self.cutoff * float(torch.abs(self.cell.det())))
)
)
smearing_init = self.cutoff / ratio

return float(smearing_init)

def tune(
self,
@@ -113,13 +209,10 @@ def tune(
cutoff_err_opt = None
err_opt = torch.inf

smearing, cutoff = _estimate_smearing_cutoff(
self.cell,
smearing=None,
cutoff=self.cutoff,
smearing = self._estimate_smearing(
accuracy=accuracy,
prefac=self._prefac,
)
cutoff = self.cutoff
for param_values in product(*self.GridSearchParams.values()):
params = dict(zip(self.GridSearchParams.keys(), param_values))
err = self.err_func(
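With the helpers inlined, `tune` estimates the smearing once from the fixed cutoff and then walks the Cartesian product of all candidate parameter values, tracking the combination with the lowest estimated error. A simplified sketch of that loop; `err_func` stands in for the tuner's `self.err_func`, and the rest of the real method's bookkeeping is omitted:

```python
from itertools import product


def grid_search(param_grid: dict, err_func) -> tuple[dict, float]:
    """Walk every combination of candidate parameter values and return the
    combination with the smallest estimated error."""
    best_params: dict = {}
    best_err = float("inf")
    for param_values in product(*param_grid.values()):
        params = dict(zip(param_grid.keys(), param_values))
        err = err_func(**params)
        if err < best_err:
            best_params, best_err = params, err
    return best_params, best_err
```

For example, `grid_search({"lr_wavelength": [1.0, 2.0, 4.0]}, lambda lr_wavelength: abs(lr_wavelength - 2.5))` returns `({"lr_wavelength": 2.0}, 0.5)`.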
8 changes: 6 additions & 2 deletions src/torchpme/utils/tuning/p3m.py
@@ -160,9 +160,10 @@ class P3MTuner(GridSearchBase):

ErrorBounds = P3MErrorBounds
CalculatorClass = P3MCalculator
GridSearchParams = {
TemplateGridSearchParams = {
"interpolation_nodes": [2, 3, 4, 5],
"mesh_spacing": 1 / ((np.exp2(np.arange(2, 8)) - 1) / 2),
"mesh_spacing": 1
/ ((np.exp2(np.arange(2, 8)) - 1) / 2), # will be converted into a list later
}

def __init__(
@@ -185,6 +186,9 @@ def __init__(
neighbor_distances,
)
self.GridSearchParams["mesh_spacing"] *= float(torch.min(self._cell_dimensions))
self.GridSearchParams["mesh_spacing"] = list(
self.GridSearchParams["mesh_spacing"]
)


def tune_p3m(
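As in `ewald.py`, the P3M (and PME) tuners now scale a template rather than a mutable class attribute: the `mesh_spacing` candidates `1 / ((2**k - 1) / 2)` for `k = 2 … 7` are multiplied by the smallest cell dimension per instance and stored as a plain list. A quick, illustrative check of the unscaled template values:

```python
import numpy as np

# Unscaled mesh-spacing candidates from the template:
# 1 / ((2**k - 1) / 2) for k = 2..7
mesh_spacing_template = 1 / ((np.exp2(np.arange(2, 8)) - 1) / 2)
print(np.round(mesh_spacing_template, 3))
# [0.667 0.286 0.133 0.065 0.032 0.016]
```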
8 changes: 6 additions & 2 deletions src/torchpme/utils/tuning/pme.py
@@ -104,9 +104,10 @@ class PMETuner(GridSearchBase):

ErrorBounds = PMEErrorBounds
CalculatorClass = PMECalculator
GridSearchParams = {
TemplateGridSearchParams = {
"interpolation_nodes": [3, 4, 5, 6, 7],
"mesh_spacing": 1 / ((np.exp2(np.arange(2, 8)) - 1) / 2),
"mesh_spacing": 1
/ ((np.exp2(np.arange(2, 8)) - 1) / 2), # will be converted into a list later
}

def __init__(
@@ -129,6 +130,9 @@ def __init__(
neighbor_distances,
)
self.GridSearchParams["mesh_spacing"] *= float(torch.min(self._cell_dimensions))
self.GridSearchParams["mesh_spacing"] = list(
self.GridSearchParams["mesh_spacing"]
)


def tune_pme(
