From 33c40a92b10e8bc23513aae57f5a3dd1ecacfbc2 Mon Sep 17 00:00:00 2001 From: Albert Villanova del Moral <8515462+albertvillanova@users.noreply.github.com> Date: Mon, 16 Dec 2024 18:06:04 +0100 Subject: [PATCH] Support custom results path --- src/lighteval/logging/evaluation_tracker.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/lighteval/logging/evaluation_tracker.py b/src/lighteval/logging/evaluation_tracker.py index 01705534..bcd29981 100644 --- a/src/lighteval/logging/evaluation_tracker.py +++ b/src/lighteval/logging/evaluation_tracker.py @@ -183,7 +183,7 @@ def save(self) -> None: details_datasets[task_name] = dataset # We save results at every case - self.save_results(date_id, results_dict) + self.save_results(date_id=date_id, results_dict=results_dict) if self.should_save_details: self.save_details(date_id, details_datasets) @@ -200,12 +200,19 @@ def save(self) -> None: results=self.metrics_logger.metric_aggregated, details=self.details_logger.compiled_details ) - def save_results(self, date_id: str, results_dict: dict): - output_dir_results = Path(self.output_dir) / "results" / self.general_config_logger.model_name - self.fs.mkdirs(output_dir_results, exist_ok=True) - output_results_file = output_dir_results / f"results_{date_id}.json" - logger.info(f"Saving results to {output_results_file}") - with self.fs.open(output_results_file, "w") as f: + def save_results( + self, date_id: str | None = None, results_dict: dict | None = None, output_path: str | None = None + ): + if output_path: + fs, output_path = url_to_fs(output_path) + else: + output_path = ( + Path(self.output_dir) / "results" / self.general_config_logger.model_name / f"results_{date_id}.json" + ) + fs = self.fs + fs.mkdirs(output_path.parent, exist_ok=True) + logger.info(f"Saving results to {output_path}") + with fs.open(output_path, "w") as f: f.write(json.dumps(results_dict, cls=EnhancedJSONEncoder, indent=2, ensure_ascii=False)) def save_details(self, date_id: str, details_datasets: dict[str, Dataset]):