Update num-test-cases to be compatible with multiple LLMs and num_queries_per_prompt

Summary: The variable `num_test_cases` specifies the number of prompts to evaluate. However, when testing multiple LLMs, or running multiple queries per prompt, these factors are not reflected in `num_test_cases` during the expansion/judge step. This can cause a mismatch between the number of responses and the final result count.

Reviewed By: YueLi28

Differential Revision: D56826808

fbshipit-source-id: b1fde2cca5d62788fa0ed3ab8df4ac5965a1be55
SimonWan authored and facebook-github-bot committed May 1, 2024
1 parent 68f1f5e commit 9d74699
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion CybersecurityBenchmarks/benchmark/run.py
@@ -146,7 +146,7 @@ def main(
         "--num-test-cases",
         type=int,
         default=0,
-        help="Only set this argument for testing purpose: If set to a positive value, we test the given number of cases instead of all cases.",
+        help="Only set this argument for testing purpose: If set to a positive value, we test the given number of prompts instead of all prompts.",
     )
 
     parser.add_argument(
@@ -214,6 +214,14 @@ def main(
 
     if not args.only_generate_llm_responses:
         LOG.info("Evaluating...")
+        # num_test_cases specifies the number of prompts to evaluate; when testing multiple LLMs or running multiple queries per prompt, the steps below scale it so that it matches the number of responses seen in the expansion/judge step.
+        if not args.use_precomputed_responses:
+            if len(materialized_llms_under_test) > 0:
+                args.num_test_cases = (
+                    len(materialized_llms_under_test) * args.num_test_cases
+                )
+        if args.num_queries_per_prompt:
+            args.num_test_cases = args.num_test_cases * args.num_queries_per_prompt
         asyncio.run(benchmark.run(args.num_test_cases))
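The scaling rule introduced by this commit can be sketched in isolation as follows. This is a hedged illustration, not the benchmark's API: the function name `scale_num_test_cases` and its parameters are hypothetical stand-ins for `args.num_test_cases`, `len(materialized_llms_under_test)`, `args.num_queries_per_prompt`, and `args.use_precomputed_responses`, mirroring the branch structure shown in the diff above.

```python
def scale_num_test_cases(
    num_test_cases: int,
    num_llms: int = 0,
    num_queries_per_prompt: int = 0,
    use_precomputed_responses: bool = False,
) -> int:
    """Scale the prompt count so it matches the number of responses
    that the expansion/judge step will actually see."""
    # Each LLM under test produces its own set of responses, so the
    # prompt count multiplies by the number of LLMs (unless responses
    # were precomputed, in which case no per-LLM expansion happens).
    if not use_precomputed_responses and num_llms > 0:
        num_test_cases = num_llms * num_test_cases
    # Each extra query per prompt likewise multiplies the response count.
    if num_queries_per_prompt:
        num_test_cases = num_test_cases * num_queries_per_prompt
    return num_test_cases


# e.g. 10 prompts, 3 LLMs, 2 queries per prompt -> 60 expected responses
print(scale_num_test_cases(10, num_llms=3, num_queries_per_prompt=2))
```

Without this adjustment, `benchmark.run(args.num_test_cases)` would be handed the raw prompt count (10 in the example) while the judge step processes the expanded response set (60), producing the mismatch the summary describes.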

