Skip to content

Commit

Permalink
Merge branch 'main' into scaler
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster authored Dec 13, 2024
2 parents 2f8ed7e + c52e9f2 commit 53566b8
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions src/metatrain/cli/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,11 +212,15 @@ def _eval_targets(
total_time = 0.0
timings_per_atom = []

# Warm up with a single batch 5 times (to get accurate timings later)
batch = next(iter(dataloader))
systems = batch[0]
systems = [system.to(dtype=dtype, device=device) for system in systems]
for _ in range(5):
# Warm up with a single batch 10 times (to get accurate timings later).
# We use different batches to warm up torch potentially with different
# tensor sizes, so dynamic shape compilation happens. We have to cycle
# the dataloader in case there are few batches.
cycled_dataloader = itertools.cycle(dataloader)
for _ in range(10):
batch = next(cycled_dataloader)
systems = batch[0]
systems = [system.to(dtype=dtype, device=device) for system in systems]
evaluate_model(
model,
systems,
Expand Down Expand Up @@ -282,8 +286,8 @@ def _eval_targets(
std_per_atom = np.std(timings_per_atom)
logger.info(
f"evaluation time: {total_time:.2f} s "
f"[{1000.0*mean_per_atom:.2f} ± "
f"{1000.0*std_per_atom:.2f} ms per atom]"
f"[{1000.0*mean_per_atom:.4f} ± "
f"{1000.0*std_per_atom:.4f} ms per atom]"
)

if return_predictions:
Expand Down

0 comments on commit 53566b8

Please sign in to comment.