diff --git a/nemo_skills/evaluation/fill_majority_answer.py b/nemo_skills/evaluation/fill_majority_answer.py index 04b2c7fa7..cd6e5e7d3 100644 --- a/nemo_skills/evaluation/fill_majority_answer.py +++ b/nemo_skills/evaluation/fill_majority_answer.py @@ -14,6 +14,7 @@ import json import logging +import shutil import sys from collections import Counter from itertools import zip_longest @@ -106,8 +107,12 @@ def fill_majority_answer(cfg: FillMajorityAnswerConfig): for file_handle in file_handles: file_handle.close() - # writing the majority answers back to the files - file_handles = [open(file, "wt", encoding="utf-8") for file in unroll_files(cfg.input_files)] + # TODO: change to instead write to a fully new set of files + # Create temp filenames and open temp files for writing + input_files = unroll_files(cfg.input_files) + temp_files = [f"{file}-tmp" for file in input_files] + file_handles = [open(temp_file, "wt", encoding="utf-8") for temp_file in temp_files] + for idx, predictions in enumerate(all_predictions): for lidx, handle in enumerate(file_handles): if cfg.ignore_if_not_none and predictions[lidx][cfg.fill_key] is not None: @@ -118,7 +123,6 @@ def fill_majority_answer(cfg: FillMajorityAnswerConfig): predictions[lidx]["majority_votes"], predictions[lidx]["total_votes"] = new_answers[idx][1] else: predictions[lidx]["answer_rm_score"] = new_answers[idx][1] - # this is just a string match check, so for full correctness need to rerun the evaluator if cfg.fill_is_correct: predictions[lidx]["is_correct"] = ( predictions[lidx]["predicted_answer"] == predictions[lidx]["expected_answer"] @@ -127,8 +131,14 @@ def fill_majority_answer(cfg: FillMajorityAnswerConfig): predictions[lidx].pop("is_correct") handle.write(json.dumps(predictions[lidx]) + "\n") - for file_handle in file_handles: - file_handle.close() + # Close all files before moving + for handle in file_handles: + handle.close() + + # Move temp files to original files + input_files = unroll_files(cfg.input_files) + for temp_file, orig_file in zip(temp_files, input_files): + shutil.move(temp_file, orig_file) HELP_MESSAGE = get_help_message(FillMajorityAnswerConfig)