Skip to content

Commit

Permalink
Speed up Bootstrapping Computation (#409)
Browse files Browse the repository at this point in the history
* Added fix for heavy recomputation of sample level metrics.

* Moved parallelization to where it is actually useful.

---------

Co-authored-by: Clémentine Fourrier <[email protected]>
Co-authored-by: Nathan Habib <[email protected]>
  • Loading branch information
3 people authored Dec 4, 2024
1 parent 9bfa1ea commit 8e977cb
Showing 1 changed file with 17 additions and 21 deletions.
38 changes: 17 additions & 21 deletions src/lighteval/metrics/stderr.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

import math
import random
from typing import Callable
from typing import Callable, Optional

import numpy as np
from scipy.stats import bootstrap
Expand All @@ -45,18 +45,27 @@ def mean_stderr(arr):


class _bootstrap_internal:
def __init__(self, metric: Callable, number_draws: int):
self.metric = metric
def __init__(self, number_draws: int, metric: Optional[Callable] = None):
self.number_draws = number_draws
self.metric = metric

def __call__(self, cur_experiment):
# Creates number_draws samplings (with replacement) of the population by iterating on a given seed
population, seed = cur_experiment
rnd = random.Random()
rnd.seed(seed)
samplings = []
for _ in range(self.number_draws):
samplings.append(self.metric(rnd.choices(population, k=len(population))))
import multiprocessing as mp

with mp.Pool(mp.cpu_count()) as pool:
samplings = pool.starmap(
self.metric,
tqdm(
[(rnd.choices(population, k=len(population)),) for _ in range(self.number_draws)],
total=self.number_draws,
desc="Sampling bootstrap iterations",
),
)
return samplings


Expand All @@ -65,28 +74,15 @@ def bootstrap_stderr(metric: Callable, population: list, number_experiments: int
by sampling said population for number_experiments and recomputing the metric on the
different samplings.
"""
import multiprocessing as mp

pool = mp.Pool(mp.cpu_count())

res = []
number_draws = min(1000, number_experiments)
# We change the seed every 1000 re-samplings
# and do the experiment 1000 re-samplings at a time
number_seeds = number_experiments // number_draws

hlog(f"Bootstrapping {metric.__name__}'s stderr.")
for cur_bootstrap in tqdm(
pool.imap(
_bootstrap_internal(metric=metric, number_draws=number_draws),
((population, seed) for seed in range(number_seeds)),
),
total=number_seeds,
):
hlog(f"Bootstrapping {metric.__name__}'s stderr with {number_seeds} seeds.")
for seed in range(number_seeds):
# sample w replacement
res.extend(cur_bootstrap)
res.extend(_bootstrap_internal(metric=metric, number_draws=number_draws)((population, seed)))

pool.close()
return mean_stderr(res)


Expand Down

0 comments on commit 8e977cb

Please sign in to comment.