Skip to content

Commit

Permalink
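Fix bugs: correct operator precedence in the cost estimate (divide the whole token cost by 1,000) and fix baseline-model handling (pass the local `baseline_model`; log it only in `pairwise-baseline` mode)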
Browse files: browse the repository at this point in the history
  • Loading branch information
hkiyomaru committed Dec 18, 2023
1 parent df1f932 commit b392de1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
4 changes: 2 additions & 2 deletions llm_judge/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,9 @@ def estimate_cost(self) -> float:
)
num_output_tokens = 200 # Estimated from a few samples
if self.judge.model == "gpt-4":
return 0.03 * num_input_tokens + 0.06 * num_output_tokens / 1_000
return (0.03 * num_input_tokens + 0.06 * num_output_tokens) / 1_000
elif self.judge.model == "gpt-3.5-turbo":
return 0.001 * num_input_tokens + 0.002 * num_output_tokens / 1_000
return (0.001 * num_input_tokens + 0.002 * num_output_tokens) / 1_000
raise AssertionError


Expand Down
5 changes: 3 additions & 2 deletions llm_judge/gen_judgment.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def make_match_groups_pairwise(
ref_answers=ref_answers,
judge_default=Judge(args.judge_model, judge_prompts["pair"]),
judge_math=Judge(args.judge_model, judge_prompts["pair-math"]),
baseline_model=args.baseline_model,
baseline_model=baseline_model,
)
output_dir = JUDGEMENT_DIR / "pairwise" / args.judge_model
target_match_ids = set()
Expand All @@ -256,7 +256,8 @@ def make_match_groups_pairwise(

logger.info(f"Mode: {args.mode}")
logger.info(f"Judge model: {args.judge_model}")
logger.info(f"Baseline model: {args.baseline_model}")
if args.mode == "pairwise-baseline":
logger.info(f"Baseline model: {args.baseline_model}")
logger.info(f"Total number of questions: {len(questions):,}")
logger.info(
f"Total number of matches: {sum(len(matches) for matches in match_groups.values()):,}"
Expand Down

0 comments on commit b392de1

Please sign in to comment.