From 8e977cbb12dad792a245a03105ebd86cb96ed571 Mon Sep 17 00:00:00 2001 From: Joel Niklaus Date: Wed, 4 Dec 2024 03:48:53 -0800 Subject: [PATCH] Speed up Bootstrapping Computation (#409) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Added fix for heavy recomputation of sample level metrics. * Moved parallelization to where it is actually useful. --------- Co-authored-by: Clémentine Fourrier <22726840+clefourrier@users.noreply.github.com> Co-authored-by: Nathan Habib <30601243+NathanHB@users.noreply.github.com> --- src/lighteval/metrics/stderr.py | 38 +++++++++++++++------------------ 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/src/lighteval/metrics/stderr.py b/src/lighteval/metrics/stderr.py index 679543966..388521a6c 100644 --- a/src/lighteval/metrics/stderr.py +++ b/src/lighteval/metrics/stderr.py @@ -26,7 +26,7 @@ import math import random -from typing import Callable +from typing import Callable, Optional import numpy as np from scipy.stats import bootstrap @@ -45,9 +45,9 @@ def mean_stderr(arr): class _bootstrap_internal: - def __init__(self, metric: Callable, number_draws: int): - self.metric = metric + def __init__(self, number_draws: int, metric: Optional[Callable] = None): self.number_draws = number_draws + self.metric = metric def __call__(self, cur_experiment): # Creates number_draws samplings (with replacement) of the population by iterating on a given seed @@ -55,8 +55,17 @@ def __call__(self, cur_experiment): rnd = random.Random() rnd.seed(seed) samplings = [] - for _ in range(self.number_draws): - samplings.append(self.metric(rnd.choices(population, k=len(population)))) + import multiprocessing as mp + + with mp.Pool(mp.cpu_count()) as pool: + samplings = pool.starmap( + self.metric, + tqdm( + [(rnd.choices(population, k=len(population)),) for _ in range(self.number_draws)], + total=self.number_draws, + desc="Sampling bootstrap iterations", + ), + ) return samplings @@ -65,28 +74,15 @@ def bootstrap_stderr(metric: Callable, population: list, number_experiments: int by sampling said population for number_experiments and recomputing the metric on the different samplings. """ - import multiprocessing as mp - - pool = mp.Pool(mp.cpu_count()) - res = [] number_draws = min(1000, number_experiments) - # We change the seed every 1000 re-samplings - # and do the experiment 1000 re-samplings at a time number_seeds = number_experiments // number_draws - hlog(f"Bootstrapping {metric.__name__}'s stderr.") - for cur_bootstrap in tqdm( - pool.imap( - _bootstrap_internal(metric=metric, number_draws=number_draws), - ((population, seed) for seed in range(number_seeds)), - ), - total=number_seeds, - ): + hlog(f"Bootstrapping {metric.__name__}'s stderr with {number_seeds} seeds.") + for seed in range(number_seeds): # sample w replacement - res.extend(cur_bootstrap) + res.extend(_bootstrap_internal(metric=metric, number_draws=number_draws)((population, seed))) - pool.close() return mean_stderr(res)