Skip to content

Commit

Permalink
Fill majority safer version
Browse files: browse the repository at this point in the history
Signed-off-by: Igor Gitman <[email protected]>
  • Loading branch information
Kipok committed Dec 14, 2024
1 parent f16ad00 commit 026ff64
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions nemo_skills/evaluation/fill_majority_answer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import json
import logging
import shutil
import sys
from collections import Counter
from itertools import zip_longest
Expand Down Expand Up @@ -106,8 +107,12 @@ def fill_majority_answer(cfg: FillMajorityAnswerConfig):
for file_handle in file_handles:
file_handle.close()

# writing the majority answers back to the files
file_handles = [open(file, "wt", encoding="utf-8") for file in unroll_files(cfg.input_files)]
# TODO: write the results to a completely new set of files instead
# Create temp filenames and open temp files for writing
input_files = unroll_files(cfg.input_files)
temp_files = [f"{file}-tmp" for file in input_files]
file_handles = [open(temp_file, "wt", encoding="utf-8") for temp_file in temp_files]

for idx, predictions in enumerate(all_predictions):
for lidx, handle in enumerate(file_handles):
if cfg.ignore_if_not_none and predictions[lidx][cfg.fill_key] is not None:
Expand All @@ -118,7 +123,6 @@ def fill_majority_answer(cfg: FillMajorityAnswerConfig):
predictions[lidx]["majority_votes"], predictions[lidx]["total_votes"] = new_answers[idx][1]
else:
predictions[lidx]["answer_rm_score"] = new_answers[idx][1]
# this is just a string match check, so for full correctness need to rerun the evaluator
if cfg.fill_is_correct:
predictions[lidx]["is_correct"] = (
predictions[lidx]["predicted_answer"] == predictions[lidx]["expected_answer"]
Expand All @@ -127,8 +131,14 @@ def fill_majority_answer(cfg: FillMajorityAnswerConfig):
predictions[lidx].pop("is_correct")
handle.write(json.dumps(predictions[lidx]) + "\n")

for file_handle in file_handles:
file_handle.close()
# Close all files before moving
for handle in file_handles:
handle.close()

# Move temp files to original files
input_files = unroll_files(cfg.input_files)
for temp_file, orig_file in zip(temp_files, input_files):
shutil.move(temp_file, orig_file)


HELP_MESSAGE = get_help_message(FillMajorityAnswerConfig)
Expand Down

0 comments on commit 026ff64

Please sign in to comment.