Skip to content

Commit

Permalink
Linting
Browse files Browse the repository at this point in the history
  • Loading branch information
ceriottm committed Dec 29, 2024
1 parent 757add8 commit 3eaf7bb
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 13 deletions.
4 changes: 1 addition & 3 deletions examples/10-tuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,7 @@

# and this is how long it took to run with these parameters (est.)

timings = TuningTimings(charges, cell, positions,
cutoff=max_cutoff,
run_backward=True)
timings = TuningTimings(charges, cell, positions, cutoff=max_cutoff, run_backward=True)
estimated_timing = timings(pme)

print(f"""
Expand Down
11 changes: 5 additions & 6 deletions src/torchpme/utils/tuning/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import math
import time
from typing import Optional

import time
import torch
import vesin.torch

Expand Down Expand Up @@ -121,7 +121,7 @@ def __init__(

def forward(self, *args, **kwargs):
return self.error(*args, **kwargs)


class TuningTimings(torch.nn.Module):
"""Base class for error bounds."""
Expand Down Expand Up @@ -172,7 +172,6 @@ def forward(self, calculator: torch.nn.Module):
Estimate the execution time of a given calculator for the structure
to be used as benchmark.
"""

for _ in range(self._n_warmup):
result = calculator.forward(
positions=self._positions,
Expand All @@ -194,14 +193,14 @@ def forward(self, calculator: torch.nn.Module):
positions.requires_grad_(True)
cell.requires_grad_(True)
charges.requires_grad_(True)
execution_time -= time.time()
execution_time -= time.time()
result = calculator.forward(
positions=positions,
charges=charges,
cell=cell,
neighbor_indices=self._neighbor_indices,
neighbor_distances=self._neighbor_distances,
)
)
value = result.sum()
if self._run_backward:
value.backward(retain_graph=True)
Expand All @@ -210,4 +209,4 @@ def forward(self, calculator: torch.nn.Module):
torch.cuda.synchronize()
execution_time += time.time()

return execution_time / self._n_repeat
return execution_time / self._n_repeat
17 changes: 13 additions & 4 deletions src/torchpme/utils/tuning/grid_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
_validate_parameters,
)


class GridSearchBase:
r"""
Base class for finding the optimal parameters for calculators using a grid search.
Expand Down Expand Up @@ -60,9 +61,17 @@ def __init__(
self.device = charges.device
self.err_func = self.ErrorBounds(charges, cell, positions)
self._cell_dimensions = torch.linalg.norm(cell, dim=1)
self.time_func = self.Timings(charges, cell, positions, cutoff,
neighbor_indices, neighbor_distances,
4, 2, True)
self.time_func = self.Timings(
charges,
cell,
positions,
cutoff,
neighbor_indices,
neighbor_distances,
4,
2,
True,
)

self._prefac = 2 * (charges**2).sum() / math.sqrt(len(positions))

Expand Down Expand Up @@ -156,4 +165,4 @@ def _timing(self, smearing: float, params: dict):
**params,
)

return self.time_func(calculator)
return self.time_func(calculator)

0 comments on commit 3eaf7bb

Please sign in to comment.