From 9d74699c04a9313d471a06d1e7810b7524e93f26 Mon Sep 17 00:00:00 2001 From: Shengye Wan Date: Wed, 1 May 2024 11:34:46 -0700 Subject: [PATCH] Update num-test-cases to be compatible with multiple LLMs and num_queries_per_prompt Summary: The variable `num_test_cases` is defined to specify the number of prompts that need to be evaluated. However, if we are testing multiple LLMs or if we are conducting multiple rounds of testing queries per prompt, these factors will not be reflected in num_test_cases during the expansion/judge step. This discrepancy can lead to a mismatch between the number of responses and the final result count. Reviewed By: YueLi28 Differential Revision: D56826808 fbshipit-source-id: b1fde2cca5d62788fa0ed3ab8df4ac5965a1be55 --- CybersecurityBenchmarks/benchmark/run.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/CybersecurityBenchmarks/benchmark/run.py b/CybersecurityBenchmarks/benchmark/run.py index 97cb7c846..89cc185e4 100644 --- a/CybersecurityBenchmarks/benchmark/run.py +++ b/CybersecurityBenchmarks/benchmark/run.py @@ -146,7 +146,7 @@ def main( "--num-test-cases", type=int, default=0, - help="Only set this argument for testing purpose: If set to a positive value, we test the given number of cases instead of all cases.", + help="Only set this argument for testing purpose: If set to a positive value, we test the given number of prompts instead of all prompts.", ) parser.add_argument( @@ -214,6 +214,14 @@ def main( if not args.only_generate_llm_responses: LOG.info("Evaluating...") + # The variable num_test_cases is defined to specify the number of prompts that need to be evaluated. However, if we are testing multiple LLMs or if we are conducting multiple rounds of testing queries per prompt, these factors will need below steps to reflect in `num_test_cases`` during the expansion/judge step. + if not args.use_precomputed_responses: + if len(materialized_llms_under_test) > 0: + args.num_test_cases = ( + len(materialized_llms_under_test) * args.num_test_cases + ) + if args.num_queries_per_prompt: + args.num_test_cases = args.num_test_cases * args.num_queries_per_prompt asyncio.run(benchmark.run(args.num_test_cases))