Skip to content

Commit

Permalink
Rearrange again
Browse files Browse the repository at this point in the history
  • Loading branch information
GardevoirX committed Jan 15, 2025
1 parent 630a526 commit fe6b045
Show file tree
Hide file tree
Showing 7 changed files with 327 additions and 262 deletions.
2 changes: 1 addition & 1 deletion src/torchpme/tuning/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .ewald import tune_ewald
from .pme import tune_pme
from .p3m import tune_p3m
from .pme import tune_pme

__all__ = [
"tune_ewald",
Expand Down
247 changes: 0 additions & 247 deletions src/torchpme/tuning/base.py

This file was deleted.

22 changes: 20 additions & 2 deletions src/torchpme/tuning/error_bounds.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,28 @@
import math
import torch

from .base import TuningErrorBounds

TWO_PI = 2 * math.pi


class TuningErrorBounds(torch.nn.Module):
"""Base class for error bounds."""

def __init__(
self,
charges: torch.Tensor,
cell: torch.Tensor,
positions: torch.Tensor,
):
super().__init__()
self._charges = charges
self._cell = cell
self._positions = positions

def forward(self, *args, **kwargs):
return self.error(*args, **kwargs)


class EwaldErrorBounds(TuningErrorBounds):
r"""
Error bounds for :class:`torchpme.calculators.ewald.EwaldCalculator`.
Expand Down Expand Up @@ -72,7 +90,7 @@ def forward(self, smearing, lr_wavelength, cutoff):
self.err_kspace(smearing, lr_wavelength) ** 2
+ self.err_rspace(smearing, cutoff) ** 2
)


# Coefficients for the P3M Fourier error,
# see Table II of http://dx.doi.org/10.1063/1.477415
Expand Down
10 changes: 6 additions & 4 deletions src/torchpme/tuning/ewald.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import math
from typing import Optional

import torch

from ..calculators import EwaldCalculator
from .tuner import GridSearchTuner

TWO_PI = 2 * math.pi


def tune_ewald(
charges: torch.Tensor,
Expand Down Expand Up @@ -76,7 +80,6 @@ def tune_ewald(
4.4
"""

params = [{"lr_wavelength": ns} for ns in range(ns_lo, ns_hi + 1)]
tuner = GridSearchTuner(
charges,
Expand All @@ -97,6 +100,5 @@ def tune_ewald(
# calculation time. The timing of those parameters leading to an higher error
# than the accuracy are set to infinity
return smearing, params[timings.index(min(timings))]
else:
# No parameter meets the requirement, return the one with the smallest error
return smearing, params[errs.index(min(errs))]
# No parameter meets the requirement, return the one with the smallest error
return smearing, params[errs.index(min(errs))]
Loading

0 comments on commit fe6b045

Please sign in to comment.