From 8864beecfadf4965b6aa9ac755341362714ce0a2 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Sun, 6 Oct 2024 22:47:04 -0700 Subject: [PATCH] [core] remove beam search from the core (#9105) Signed-off-by: Alvant --- benchmarks/backend_request_func.py | 6 - benchmarks/benchmark_latency.py | 3 +- benchmarks/benchmark_prioritization.py | 24 ++- benchmarks/benchmark_serving.py | 7 - benchmarks/benchmark_throughput.py | 29 ++-- examples/llm_engine_example.py | 3 - examples/multilora_inference.py | 18 --- tests/basic_correctness/test_preemption.py | 114 +------------- tests/conftest.py | 14 -- tests/core/block/e2e/test_correctness.py | 67 -------- tests/core/utils.py | 7 +- tests/samplers/test_beam_search.py | 4 +- tests/samplers/test_sampler.py | 30 +--- vllm/core/scheduler.py | 4 +- vllm/engine/async_llm_engine.py | 16 +- vllm/engine/output_processor/single_step.py | 164 +------------------- vllm/entrypoints/llm.py | 13 +- vllm/entrypoints/openai/protocol.py | 10 +- vllm/envs.py | 5 - vllm/model_executor/layers/sampler.py | 9 +- vllm/outputs.py | 6 +- vllm/sampling_params.py | 73 +-------- vllm/sequence.py | 46 ++---- vllm/utils.py | 19 +++ vllm/worker/tpu_model_runner.py | 3 - 25 files changed, 98 insertions(+), 596 deletions(-) diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index bcd38461617a8..4813fde27f0bc 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -23,7 +23,6 @@ class RequestFuncInput: output_len: int model: str best_of: int = 1 - use_beam_search: bool = False logprobs: Optional[int] = None multi_modal_content: Optional[dict] = None ignore_eos: bool = False @@ -49,7 +48,6 @@ async def async_request_tgi( assert api_url.endswith("generate_stream") async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - assert not request_func_input.use_beam_search params = { "best_of": request_func_input.best_of, "max_new_tokens": request_func_input.output_len, @@ -121,7 +119,6 @@ async def async_request_trt_llm( assert api_url.endswith("generate_stream") async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - assert not request_func_input.use_beam_search assert request_func_input.best_of == 1 payload = { "accumulate_tokens": True, @@ -187,7 +184,6 @@ async def async_request_deepspeed_mii( ) -> RequestFuncOutput: async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: assert request_func_input.best_of == 1 - assert not request_func_input.use_beam_search payload = { "prompt": request_func_input.prompt, @@ -235,7 +231,6 @@ async def async_request_openai_completions( ), "OpenAI Completions API URL must end with 'completions' or 'profile'." async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - assert not request_func_input.use_beam_search payload = { "model": request_func_input.model, "prompt": request_func_input.prompt, @@ -317,7 +312,6 @@ async def async_request_openai_chat_completions( ), "OpenAI Chat Completions API URL must end with 'chat/completions'." 
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session: - assert not request_func_input.use_beam_search content = [{"type": "text", "text": request_func_input.prompt}] if request_func_input.multi_modal_content: content.append(request_func_input.multi_modal_content) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index eadf994cacd34..938d7acd5687c 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -51,9 +51,8 @@ def main(args: argparse.Namespace): sampling_params = SamplingParams( n=args.n, - temperature=0.0 if args.use_beam_search else 1.0, + temperature=1.0, top_p=1.0, - use_beam_search=args.use_beam_search, ignore_eos=True, max_tokens=args.output_len, ) diff --git a/benchmarks/benchmark_prioritization.py b/benchmarks/benchmark_prioritization.py index 0ba29fabca59b..8843e3a927a01 100644 --- a/benchmarks/benchmark_prioritization.py +++ b/benchmarks/benchmark_prioritization.py @@ -68,7 +68,6 @@ def run_vllm( tensor_parallel_size: int, seed: int, n: int, - use_beam_search: bool, trust_remote_code: bool, dtype: str, max_model_len: Optional[int], @@ -114,9 +113,8 @@ def run_vllm( sampling_params.append( SamplingParams( n=n, - temperature=0.0 if use_beam_search else 1.0, + temperature=1.0, top_p=1.0, - use_beam_search=use_beam_search, ignore_eos=True, max_tokens=output_len, )) @@ -144,15 +142,16 @@ def main(args: argparse.Namespace): args.output_len) if args.backend == "vllm": - elapsed_time = run_vllm( - requests, args.model, args.tokenizer, args.quantization, - args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, - args.trust_remote_code, args.dtype, args.max_model_len, - args.enforce_eager, args.kv_cache_dtype, - args.quantization_param_path, args.device, - args.enable_prefix_caching, args.enable_chunked_prefill, - args.max_num_batched_tokens, args.gpu_memory_utilization, - args.download_dir) + elapsed_time = run_vllm(requests, args.model, args.tokenizer, + args.quantization, args.tensor_parallel_size, + args.seed, args.n, args.trust_remote_code, + args.dtype, args.max_model_len, + args.enforce_eager, args.kv_cache_dtype, + args.quantization_param_path, args.device, + args.enable_prefix_caching, + args.enable_chunked_prefill, + args.max_num_batched_tokens, + args.gpu_memory_utilization, args.download_dir) else: raise ValueError(f"Unknown backend: {args.backend}") total_num_tokens = sum(prompt_len + output_len @@ -203,7 +202,6 @@ def main(args: argparse.Namespace): type=int, default=1, help="Number of generated sequences per prompt.") - parser.add_argument("--use-beam-search", action="store_true") parser.add_argument("--num-prompts", type=int, default=200, diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 0460f4c0094be..292d1f37fbf3e 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -391,7 +391,6 @@ async def benchmark( input_requests: List[Tuple[str, int, int]], logprobs: Optional[int], best_of: int, - use_beam_search: bool, request_rate: float, disable_tqdm: bool, profile: bool, @@ -419,7 +418,6 @@ async def benchmark( output_len=test_output_len, logprobs=logprobs, best_of=best_of, - use_beam_search=use_beam_search, multi_modal_content=test_mm_content, ignore_eos=ignore_eos, ) @@ -441,7 +439,6 @@ async def benchmark( output_len=test_output_len, logprobs=logprobs, best_of=best_of, - use_beam_search=use_beam_search, multi_modal_content=test_mm_content, ) profile_output = await request_func(request_func_input=profile_input) @@ 
-464,7 +461,6 @@ async def benchmark( output_len=output_len, logprobs=logprobs, best_of=best_of, - use_beam_search=use_beam_search, multi_modal_content=mm_content, ) tasks.append( @@ -483,7 +479,6 @@ async def benchmark( output_len=test_output_len, logprobs=logprobs, best_of=best_of, - use_beam_search=use_beam_search, ) profile_output = await request_func(request_func_input=profile_input) if profile_output.success: @@ -679,7 +674,6 @@ def main(args: argparse.Namespace): input_requests=input_requests, logprobs=args.logprobs, best_of=args.best_of, - use_beam_search=args.use_beam_search, request_rate=args.request_rate, disable_tqdm=args.disable_tqdm, profile=args.profile, @@ -701,7 +695,6 @@ def main(args: argparse.Namespace): result_json["model_id"] = model_id result_json["tokenizer_id"] = tokenizer_id result_json["best_of"] = args.best_of - result_json["use_beam_search"] = args.use_beam_search result_json["num_prompts"] = args.num_prompts # Metadata diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index c6bc607ff6b8e..3781863f77e64 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -73,7 +73,6 @@ def run_vllm( tensor_parallel_size: int, seed: int, n: int, - use_beam_search: bool, trust_remote_code: bool, dtype: str, max_model_len: Optional[int], @@ -91,7 +90,6 @@ def run_vllm( download_dir: Optional[str] = None, load_format: str = EngineArgs.load_format, disable_async_output_proc: bool = False, - use_new_beam_search_impl: bool = False, ) -> float: from vllm import LLM, SamplingParams llm = LLM( @@ -127,19 +125,19 @@ def run_vllm( sampling_params.append( SamplingParams( n=n, - temperature=0.0 if use_beam_search else 1.0, + temperature=1.0, top_p=1.0, - use_beam_search=use_beam_search, ignore_eos=True, max_tokens=output_len, )) - if not use_new_beam_search_impl: + use_beam_search = False + + if not use_beam_search: start = time.perf_counter() llm.generate(prompts, sampling_params, use_tqdm=True) end = time.perf_counter() else: - assert use_beam_search prompts = [prompt for prompt, _, _ in requests] # output_len should be the same for all requests. 
output_len = requests[0][2] @@ -165,7 +163,6 @@ async def run_vllm_async( tensor_parallel_size: int, seed: int, n: int, - use_beam_search: bool, trust_remote_code: bool, dtype: str, max_model_len: Optional[int], @@ -224,9 +221,8 @@ async def run_vllm_async( sampling_params.append( SamplingParams( n=n, - temperature=0.0 if use_beam_search else 1.0, + temperature=1.0, top_p=1.0, - use_beam_search=use_beam_search, ignore_eos=True, max_tokens=output_len, )) @@ -248,11 +244,9 @@ def run_hf( model: str, tokenizer: PreTrainedTokenizerBase, n: int, - use_beam_search: bool, max_batch_size: int, trust_remote_code: bool, ) -> float: - assert not use_beam_search llm = AutoModelForCausalLM.from_pretrained( model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code) if llm.config.model_type == "llama": @@ -284,7 +278,7 @@ def run_hf( padding=True).input_ids llm_outputs = llm.generate( input_ids=input_ids.cuda(), - do_sample=not use_beam_search, + do_sample=True, num_return_sequences=n, temperature=1.0, top_p=1.0, @@ -340,7 +334,7 @@ def main(args: argparse.Namespace): if args.backend == "vllm": run_args = [ requests, args.model, args.tokenizer, args.quantization, - args.tensor_parallel_size, args.seed, args.n, args.use_beam_search, + args.tensor_parallel_size, args.seed, args.n, args.trust_remote_code, args.dtype, args.max_model_len, args.enforce_eager, args.kv_cache_dtype, args.quantization_param_path, args.device, @@ -355,12 +349,11 @@ def main(args: argparse.Namespace): run_args.append(args.disable_frontend_multiprocessing) elapsed_time = uvloop.run(run_vllm_async(*run_args)) else: - elapsed_time = run_vllm(*run_args, args.use_new_beam_search_impl) + elapsed_time = run_vllm(*run_args) elif args.backend == "hf": assert args.tensor_parallel_size == 1 elapsed_time = run_hf(requests, args.model, tokenizer, args.n, - args.use_beam_search, args.hf_max_batch_size, - args.trust_remote_code) + args.hf_max_batch_size, args.trust_remote_code) elif args.backend == "mii": elapsed_time = run_mii(requests, args.model, args.tensor_parallel_size, args.output_len) @@ -414,8 +407,6 @@ def main(args: argparse.Namespace): type=int, default=1, help="Number of generated sequences per prompt.") - parser.add_argument("--use-beam-search", action="store_true") - parser.add_argument("--use-new-beam-search-impl", action="store_true") parser.add_argument("--num-prompts", type=int, default=1000, @@ -570,8 +561,6 @@ def main(args: argparse.Namespace): raise ValueError("dtype must be auto for MII backend.") if args.n != 1: raise ValueError("n must be 1 for MII backend.") - if args.use_beam_search: - raise ValueError("Beam search is not supported for MII backend.") if args.quantization is not None: raise ValueError("Quantization is only for vLLM backend.") if args.hf_max_batch_size is not None: diff --git a/examples/llm_engine_example.py b/examples/llm_engine_example.py index ca41f32b12b31..60d894aae9692 100644 --- a/examples/llm_engine_example.py +++ b/examples/llm_engine_example.py @@ -18,9 +18,6 @@ def create_test_prompts() -> List[Tuple[str, SamplingParams]]: temperature=0.8, top_p=0.95, frequency_penalty=0.1)), - ("It is only with the heart that one can see rightly", - SamplingParams(n=3, best_of=3, use_beam_search=True, - temperature=0.0)), ] diff --git a/examples/multilora_inference.py b/examples/multilora_inference.py index 6aa25b4689ec8..043220d979c3c 100644 --- a/examples/multilora_inference.py +++ b/examples/multilora_inference.py @@ -43,15 +43,6 @@ def create_test_prompts( max_tokens=128, 
stop_token_ids=[32003]), LoRARequest("sql-lora", 1, lora_path)), - ( - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 - SamplingParams(n=3, - best_of=3, - use_beam_search=True, - temperature=0, - max_tokens=128, - stop_token_ids=[32003]), - LoRARequest("sql-lora", 1, lora_path)), ( "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 SamplingParams(temperature=0.0, @@ -60,15 +51,6 @@ def create_test_prompts( max_tokens=128, stop_token_ids=[32003]), LoRARequest("sql-lora2", 2, lora_path)), - ( - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 - SamplingParams(n=3, - best_of=3, - use_beam_search=True, - temperature=0, - max_tokens=128, - stop_token_ids=[32003]), - LoRARequest("sql-lora", 1, lora_path)), ] diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index 05e7859759002..4e502cfb5f4f8 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -23,11 +23,9 @@ @pytest.fixture(scope="module", autouse=True) def check_settings(): assert ENABLE_ARTIFICIAL_PREEMPT is True, ( - "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1, " - "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1. " + "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1." 
"`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 " - "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1 pytest " - "tests/basic_correctness/test_preemption.py`") + "pytest tests/basic_correctness/test_preemption.py`") @pytest.fixture @@ -137,114 +135,6 @@ def test_preemption( assert total_preemption == total_recorded_preemption -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) -@pytest.mark.parametrize("beam_width", [4]) -def test_swap( - caplog_vllm, - hf_runner, - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - beam_width: int, - worker_use_ray: bool, -) -> None: - """Use beam search enables swapping.""" - example_prompts = example_prompts[:1] - with hf_runner(model, dtype=dtype) as hf_model: - hf_outputs = hf_model.generate_beam_search(example_prompts, beam_width, - max_tokens) - - with vllm_runner( - model, - dtype=dtype, - swap_space=10, - disable_log_stats=False, - worker_use_ray=worker_use_ray, - ) as vllm_model: - vllm_outputs = vllm_model.generate_beam_search(example_prompts, - beam_width, max_tokens) - assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt - < ARTIFICIAL_PREEMPTION_MAX_CNT) - total_preemption = ( - vllm_model.model.llm_engine.scheduler[0].num_cumulative_preemption) - - for i in range(len(example_prompts)): - hf_output_ids, _ = hf_outputs[i] - vllm_output_ids, _ = vllm_outputs[i] - assert len(hf_output_ids) == len(vllm_output_ids) - for j in range(len(hf_output_ids)): - assert hf_output_ids[j] == vllm_output_ids[j], ( - f"Test{i} output{j}:\nHF: {hf_output_ids}\n" - f"vLLM: {vllm_output_ids}") - - assert ("is preempted by PreemptionMode.SWAP mode because there " - "is not enough KV cache space." in caplog_vllm.text) - # Ensure the count bucket of request-level histogram metrics matches - # the number of requests as a simple sanity check to ensure metrics are - # generated - preemption_metrics = None - for m in REGISTRY.collect(): - if m.name == "vllm:num_preemptions": - preemption_metrics = m - assert preemption_metrics is not None - total_recorded_preemption = 0 - for sample in preemption_metrics.samples: - total_recorded_preemption += sample.value - assert total_preemption == total_recorded_preemption - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["float"]) -@pytest.mark.parametrize("max_tokens", [96]) -@pytest.mark.parametrize("beam_width", [4]) -@pytest.mark.parametrize("use_v2_block_manager", [True, False]) -def test_swap_infeasible( - vllm_runner, - example_prompts, - model: str, - dtype: str, - max_tokens: int, - beam_width: int, - worker_use_ray: bool, - use_v2_block_manager: bool, -) -> None: - """Verify infeasible swap request will be ignored.""" - BLOCK_SIZE = 16 - prefill_blocks = 2 - decode_blocks = max_tokens // BLOCK_SIZE - example_prompts = example_prompts[:1] - with vllm_runner( - model, - dtype=dtype, - swap_space=10, - block_size=BLOCK_SIZE, - # Since beam search have more than 1 sequence, prefill + - # decode blocks are not enough to finish. 
- num_gpu_blocks_override=prefill_blocks + decode_blocks, - max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE, - worker_use_ray=worker_use_ray, - use_v2_block_manager=use_v2_block_manager, - ) as vllm_model: - sampling_params = SamplingParams(n=beam_width, - use_beam_search=True, - temperature=0.0, - max_tokens=max_tokens, - ignore_eos=True) - req_outputs = vllm_model.model.generate( - example_prompts, - sampling_params=sampling_params, - ) - assert (vllm_model.model.llm_engine.scheduler[0].artificial_preempt_cnt - < ARTIFICIAL_PREEMPTION_MAX_CNT) - - # Verify the request is ignored and not hang. - assert req_outputs[0].outputs[0].finish_reason == "length" - - @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [96]) diff --git a/tests/conftest.py b/tests/conftest.py index 5de3f1f2a2b90..713be09ca96ea 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -782,7 +782,6 @@ def generate_encoder_decoder_greedy_logprobs( List[TokensTextLogprobsPromptLogprobs]]: greedy_logprobs_params = SamplingParams( temperature=0.0, - use_beam_search=False, max_tokens=max_tokens, logprobs=num_logprobs, prompt_logprobs=(num_prompt_logprobs), @@ -795,19 +794,6 @@ def generate_encoder_decoder_greedy_logprobs( encoder_decoder_prompts, greedy_logprobs_params) def generate_beam_search( - self, - prompts: List[str], - beam_width: int, - max_tokens: int, - ) -> List[Tuple[List[List[int]], List[str]]]: - beam_search_params = SamplingParams(n=beam_width, - use_beam_search=True, - temperature=0.0, - max_tokens=max_tokens) - outputs = self.generate(prompts, beam_search_params) - return outputs - - def generate_beam_search_new( self, prompts: Union[List[str], List[List[int]]], beam_width: int, diff --git a/tests/core/block/e2e/test_correctness.py b/tests/core/block/e2e/test_correctness.py index b3d3667b37d88..033778d2c35e0 100644 --- a/tests/core/block/e2e/test_correctness.py +++ b/tests/core/block/e2e/test_correctness.py @@ -85,73 +85,6 @@ def test_v1_v2_greedy_equality_with_preemption(baseline_llm_generator, assert baseline_token_ids == test_token_ids -@pytest.mark.parametrize( - "common_llm_kwargs", - [{ - # Use a small model for a fast test. - "model": "facebook/opt-125m", - - # skip cuda graph creation for fast test. - "enforce_eager": True, - - # Use a large block size to trigger more copy-on-writes. - "block_size": 32, - }]) -@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) -@pytest.mark.parametrize("baseline_llm_kwargs", [{ - "use_v2_block_manager": False -}]) -@pytest.mark.parametrize("test_llm_kwargs", [{ - "use_v2_block_manager": True, - "preemption_mode": "swap" -}, { - "use_v2_block_manager": True, - "preemption_mode": "recompute" -}]) -@pytest.mark.parametrize("batch_size", [10]) -@pytest.mark.parametrize("seed", [1]) -def test_v1_v2_greedy_equality_with_cow(baseline_llm_generator, - test_llm_generator, batch_size): - """Verify beam search equality with block manager v1 and v2. - - This requires copy-on-writes; if the v1 and v2 output is the same, then - we have some confidence cow is working. 
- """ - output_len = 128 - temperature = 0.0 - - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - prompts = [prompt for prompt, _ in zip(cycle(prompts), range(batch_size))] - - sampling_params = SamplingParams( - max_tokens=output_len, - ignore_eos=True, - temperature=temperature, - use_beam_search=True, - best_of=2, - ) - - print('Getting token ids from block manager v1') - baseline_token_ids = get_token_ids_from_llm_generator( - baseline_llm_generator, prompts, sampling_params) - - print('Getting token ids from block manager v2') - test_token_ids = get_token_ids_from_llm_generator(test_llm_generator, - prompts, sampling_params) - - for expected_token_ids, actual_token_ids in zip(baseline_token_ids, - test_token_ids): - assert expected_token_ids == actual_token_ids - - assert baseline_token_ids == test_token_ids - - @pytest.mark.parametrize( "common_llm_kwargs", [{ diff --git a/tests/core/utils.py b/tests/core/utils.py index 1e4332268c2f3..a95a573db7cd3 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -13,7 +13,6 @@ def create_dummy_prompt( prompt_length: int, block_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, - use_beam_search: bool = False, best_of: int = 1, prompt_tokens: Optional[List[int]] = None, min_tokens: int = 0, @@ -37,7 +36,6 @@ def create_dummy_prompt( seqs=[prompt], arrival_time=time.time(), sampling_params=SamplingParams( - use_beam_search=use_beam_search, best_of=best_of, max_tokens=max_tokens, min_tokens=min_tokens), @@ -52,7 +50,6 @@ def create_dummy_prompt_encoder_decoder( encoder_prompt_length: int, block_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, - use_beam_search: bool = False, best_of: int = 1, ) -> Tuple[Sequence, Sequence, SequenceGroup]: if not block_size: @@ -85,9 +82,7 @@ def create_dummy_prompt_encoder_decoder( from_decoder_prompt=False) seq_group = SequenceGroup(request_id=request_id, seqs=[decoder_prompt], - sampling_params=SamplingParams( - use_beam_search=use_beam_search, - best_of=best_of), + sampling_params=SamplingParams(best_of=best_of), arrival_time=time.time(), lora_request=lora_request, encoder_seq=encoder_prompt) diff --git a/tests/samplers/test_beam_search.py b/tests/samplers/test_beam_search.py index a9bedc2956fdd..4d1a6978d4c55 100644 --- a/tests/samplers/test_beam_search.py +++ b/tests/samplers/test_beam_search.py @@ -33,8 +33,8 @@ def test_beam_search_single_input( max_tokens) with vllm_runner(model, dtype=dtype) as vllm_model: - vllm_outputs = vllm_model.generate_beam_search_new( - example_prompts, beam_width, max_tokens) + vllm_outputs = vllm_model.generate_beam_search(example_prompts, + beam_width, max_tokens) for i in range(len(example_prompts)): hf_output_ids, hf_output_texts = hf_outputs[i] diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 9d4932dd1f5b1..28c34064f670c 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -159,26 +159,6 @@ def test_sampler_all_random_seed_deterministic(seed: int, device: str): assert first_sampler_output == second_sampler_output -@pytest.mark.parametrize("seed", RANDOM_SEEDS) -@pytest.mark.parametrize("device", CUDA_DEVICES) -def test_sampler_all_beam(seed: int, device: str): - set_random_seed(seed) - torch.set_default_device(device) - batch_size = random.randint(1, 256) - _, fake_logits, sampler = _prepare_test(batch_size) - - sampling_params = SamplingParams( - temperature=0, - 
best_of=2, - use_beam_search=True, - ) - _do_sample(batch_size, fake_logits, sampler, sampling_params, device) - # no assertion here as I am not sure how to determine whether - # the outputs are expected - in other words, this just tests - # whether there are no exceptions in the sampler - # when handling an all-beam search case. - - @pytest.mark.parametrize("seed", RANDOM_SEEDS) @pytest.mark.parametrize("device", CUDA_DEVICES) def test_sampler_min_tokens_penalty(seed: int, device: str): @@ -479,7 +459,7 @@ def test_sampler_mixed(seed: int, device: str): seq_lens: List[int] = [] for i in range(batch_size): expected: Optional[List[int]] = None - sampling_type = random.randint(0, 3) + sampling_type = random.randint(0, 2) if sampling_type == 0: sampling_params = SamplingParams(temperature=0) expected = [int(torch.argmax(fake_logits[i], dim=-1).item())] @@ -498,10 +478,7 @@ def test_sampler_mixed(seed: int, device: str): for idx in range(n): fake_logits[i, i + idx] = 1e2 expected = list(range(i, i + n)) - else: - sampling_params = SamplingParams(temperature=0, - use_beam_search=True, - best_of=2) + expected_tokens.append(expected) seq_group_metadata_list.append( SequenceGroupMetadata( @@ -530,9 +507,6 @@ def test_sampling(): zip(sampler_output, seq_group_metadata_list)): assert metadata.sampling_params is not None - if metadata.sampling_params.use_beam_search: - continue - if (metadata.sampling_params.seed is not None and expected_tokens[i] is None): # Record seeded random result to compare with results of diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index f3a5016d0e62a..c57e6cd716405 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1202,9 +1202,9 @@ def _can_append_slots(self, seq_group: SequenceGroup, seq_group=seq_group, num_lookahead_slots=num_lookahead_slots) def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: + # TODO: does it work with parallel sampling? 
no_beam_search = seq_group.sampling_params is None or ( - seq_group.sampling_params.best_of == 1 - and not seq_group.sampling_params.use_beam_search) + seq_group.sampling_params.best_of == 1) return no_beam_search def schedule( diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index d16893c706129..b982d20b4c95f 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -33,7 +33,7 @@ from vllm.transformers_utils.tokenizer import AnyTokenizer from vllm.usage.usage_lib import UsageContext from vllm.utils import (collect_from_async_generator, deprecate_kwargs, - random_uuid, weak_bind) + get_beam_search_score, random_uuid, weak_bind) logger = init_logger(__name__) ENGINE_ITERATION_TIMEOUT_S = envs.VLLM_ENGINE_ITERATION_TIMEOUT_S @@ -1092,6 +1092,12 @@ async def beam_search( max_tokens = params.max_tokens ignore_eos = params.ignore_eos temperature = params.temperature + length_penalty = params.length_penalty + + def sort_beams_key(x: BeamSearchSequence) -> float: + return get_beam_search_score(x.tokens, x.cum_logprob, + tokenizer.eos_token_id, + length_penalty) tokenizer = await self.get_tokenizer() tokenizedPrompt = prompt if isinstance( @@ -1145,15 +1151,11 @@ async def beam_search( else: new_beams.append(new_beam) - sorted_beams = sorted(new_beams, - key=lambda x: x.cum_logprob, - reverse=True) + sorted_beams = sorted(new_beams, key=sort_beams_key, reverse=True) all_beams = sorted_beams[:beam_width] completed.extend(all_beams) - sorted_completed = sorted(completed, - key=lambda x: x.cum_logprob, - reverse=True) + sorted_completed = sorted(completed, key=sort_beams_key, reverse=True) best_beams = sorted_completed[:beam_width] for beam in best_beams: diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index e288aa0c4aafd..00d9297e41d99 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Tuple from vllm.config import SchedulerConfig from vllm.core.scheduler import Scheduler @@ -6,7 +6,6 @@ SequenceGroupOutputProcessor) from vllm.engine.output_processor.stop_checker import StopChecker from vllm.logger import init_logger -from vllm.sampling_params import SamplingParams from vllm.sequence import (Sequence, SequenceGroup, SequenceGroupOutput, SequenceOutput, SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer @@ -113,7 +112,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, outputs: SequenceGroupOutput, is_async: bool) -> None: sampling_params = seq_group.sampling_params - if sampling_params.best_of == 1 and not sampling_params.use_beam_search: + if sampling_params.best_of == 1: # only have one output sample sample = outputs.samples[0] # only have one sequence @@ -142,7 +141,6 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Process samples samples = outputs.samples parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) - existing_finished_seqs = seq_group.get_finished_seqs() parent_child_dict: Dict[int, List[SequenceOutput]] = { parent_seq.seq_id: [] for parent_seq in parent_seqs @@ -197,106 +195,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, lora_req=seq_group.lora_request, ) - # Non-beam search case - if not sampling_params.use_beam_search: - # For newly created child sequences, add them to the sequence group - # and fork them in 
block manager if they are not finished. - for seq, parent in child_seqs: - if seq is not parent: - seq_group.add(seq) - if not seq.is_finished(): - for scheduler in self.scheduler: - scheduler.fork_seq(parent, seq) - - # Free the finished and selected parent sequences' memory in block - # manager. Keep them in the sequence group as candidate output. - # NOTE: we need to fork the new sequences before freeing the - # old sequences. - for seq, parent in child_seqs: - if seq is parent and seq.is_finished(): - for scheduler in self.scheduler: - scheduler.free_seq(seq) - return - - # Beam search case - # Select the child sequences to keep in the sequence group. - selected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = [] - unselected_child_seqs: List[Tuple[Sequence, Optional[Sequence]]] = [] - beam_width = sampling_params.best_of - length_penalty = sampling_params.length_penalty - - # Select the newly finished sequences with the highest scores - # to replace existing finished sequences. - # Tuple of (seq, parent, is_new) - existing_finished_seqs = [(seq, None, False) - for seq in existing_finished_seqs] - new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs - if seq.is_finished()] - all_finished_seqs = existing_finished_seqs + new_finished_seqs - # Sort the finished sequences by their scores. - all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), - reverse=True) - for seq, parent, is_new in all_finished_seqs[:beam_width]: - if is_new: - # A newly generated child sequence finishes and has a high - # score, so we will add it into the sequence group. - selected_child_seqs.append((seq, parent)) - for seq, parent, is_new in all_finished_seqs[beam_width:]: - if is_new: - # A newly generated child sequence finishes but has a low - # score, so we will not add it into the sequence group. - # Additionally, if this sequence is a continuation of a - # parent sequence, we will need remove the parent sequence - # from the sequence group. - unselected_child_seqs.append((seq, parent)) - else: - # An existing finished sequence has a low score, so we will - # remove it from the sequence group. - seq_group.remove(seq.seq_id) - - # select the top beam_width sequences from the running - # sequences for the next iteration to continue the beam - # search. - running_child_seqs = [(seq, parent) for seq, parent in child_seqs - if not seq.is_finished()] - # Sort the running sequences by their scores. - running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, eos_token_id=x[0].eos_token_id), - reverse=True) - - # Check if we can stop the beam search. - if len(running_child_seqs) == 0: - # No running sequences, stop the beam search. - stop_beam_search = True - elif len(all_finished_seqs) < beam_width: - # Not enough finished sequences, continue the beam search. - stop_beam_search = False - else: - # Check the early stopping criteria - best_running_seq = running_child_seqs[0][0] - current_worst_seq = all_finished_seqs[beam_width - 1][0] - stop_beam_search = self._check_beam_search_early_stopping( - sampling_params.early_stopping, sampling_params, - best_running_seq, current_worst_seq) - - if stop_beam_search: - # Stop the beam search and remove all the running sequences from - # the sequence group. - unselected_child_seqs.extend(running_child_seqs) - else: - # Continue the beam search and select the top beam_width sequences - # to continue the beam search. 
- selected_child_seqs.extend(running_child_seqs[:beam_width]) - # The remaining running sequences will not be used in the next - # iteration. Again, if these sequences are continuations of - # parent sequences, we will need to remove the parent sequences - # from the sequence group. - unselected_child_seqs.extend(running_child_seqs[beam_width:]) - # For newly created child sequences, add them to the sequence group # and fork them in block manager if they are not finished. - for seq, parent in selected_child_seqs: + for seq, parent in child_seqs: if seq is not parent: seq_group.add(seq) if not seq.is_finished(): @@ -305,61 +206,10 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Free the finished and selected parent sequences' memory in block # manager. Keep them in the sequence group as candidate output. - for seq, parent in selected_child_seqs: + # NOTE: we need to fork the new sequences before freeing the + # old sequences. + for seq, parent in child_seqs: if seq is parent and seq.is_finished(): for scheduler in self.scheduler: scheduler.free_seq(seq) - - # Remove the unselected parent sequences from the sequence group and - # free their memory in block manager. - for seq, parent in unselected_child_seqs: - if seq is parent: - # Remove the parent sequence if it is not selected for next - # iteration - seq_group.remove(seq.seq_id) - for scheduler in self.scheduler: - scheduler.free_seq(seq) - - def _check_beam_search_early_stopping( - self, - early_stopping: Union[bool, str], - sampling_params: SamplingParams, - best_running_seq: Sequence, - current_worst_seq: Sequence, - ) -> bool: - assert sampling_params.use_beam_search - length_penalty = sampling_params.length_penalty - if early_stopping is True: - return True - - current_worst_score = current_worst_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=current_worst_seq.eos_token_id) - if early_stopping is False: - highest_attainable_score = best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id) - else: - assert early_stopping == "never" - if length_penalty > 0.0: - # If length_penalty > 0.0, beam search will prefer longer - # sequences. The highest attainable score calculation is - # based on the longest possible sequence length in this case. - max_possible_length = max( - best_running_seq.get_prompt_len() + - sampling_params.max_tokens, - self.scheduler_config.max_model_len) - highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id, - seq_len=max_possible_length)) - else: - # Otherwise, beam search will prefer shorter sequences. The - # highest attainable score calculation is based on the current - # sequence length. 
- highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=best_running_seq.eos_token_id)) - return current_worst_score >= highest_attainable_score + return diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 1cb35ee92348d..439f3769f9fbd 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -28,7 +28,8 @@ get_cached_tokenizer) from vllm.transformers_utils.tokenizer_group import TokenizerGroup from vllm.usage.usage_lib import UsageContext -from vllm.utils import Counter, deprecate_kwargs, is_list_of +from vllm.utils import (Counter, deprecate_kwargs, get_beam_search_score, + is_list_of) logger = init_logger(__name__) @@ -404,6 +405,12 @@ def beam_search( max_tokens = params.max_tokens temperature = params.temperature ignore_eos = params.ignore_eos + length_penalty = params.length_penalty + + def sort_beams_key(x: BeamSearchSequence) -> float: + return get_beam_search_score(x.tokens, x.cum_logprob, + tokenizer.eos_token_id, + length_penalty) tokenizer = self.get_tokenizer() # generate 2 * beam_width candidates at each step @@ -466,7 +473,7 @@ def beam_search( else: instance_new_beams.append(new_beam) sorted_beams = sorted(instance_new_beams, - key=lambda x: x.cum_logprob, + key=sort_beams_key, reverse=True) instance.beams = sorted_beams[:beam_width] @@ -474,7 +481,7 @@ def beam_search( for instance in instances: instance.completed.extend(instance.beams) sorted_completed = sorted(instance.completed, - key=lambda x: x.cum_logprob, + key=sort_beams_key, reverse=True) best_beams = sorted_completed[:beam_width] diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index f0aaf3733869d..6f1135f8093ba 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -184,7 +184,6 @@ class ChatCompletionRequest(OpenAIBaseModel): min_p: float = 0.0 repetition_penalty: float = 1.0 length_penalty: float = 1.0 - early_stopping: bool = False stop_token_ids: Optional[List[int]] = Field(default_factory=list) include_stop_str_in_output: bool = False ignore_eos: bool = False @@ -302,6 +301,7 @@ def to_beam_search_params(self, max_tokens=max_tokens, ignore_eos=self.ignore_eos, temperature=temperature, + length_penalty=self.length_penalty, ) def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: @@ -345,12 +345,9 @@ def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: ignore_eos=self.ignore_eos, max_tokens=max_tokens, min_tokens=self.min_tokens, - use_beam_search=self.use_beam_search, - early_stopping=self.early_stopping, skip_special_tokens=self.skip_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, - length_penalty=self.length_penalty, truncate_prompt_tokens=self.truncate_prompt_tokens, output_kind=RequestOutputKind.DELTA if self.stream \ else RequestOutputKind.FINAL_ONLY, @@ -518,7 +515,6 @@ class CompletionRequest(OpenAIBaseModel): min_p: float = 0.0 repetition_penalty: float = 1.0 length_penalty: float = 1.0 - early_stopping: bool = False stop_token_ids: Optional[List[int]] = Field(default_factory=list) include_stop_str_in_output: bool = False ignore_eos: bool = False @@ -597,6 +593,7 @@ def to_beam_search_params(self, max_tokens=max_tokens, ignore_eos=self.ignore_eos, temperature=temperature, + length_penalty=self.length_penalty, ) def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: @@ -641,13 +638,10 @@ 
def to_sampling_params(self, default_max_tokens: int) -> SamplingParams: ignore_eos=self.ignore_eos, max_tokens=max_tokens if not echo_without_generation else 1, min_tokens=self.min_tokens, - use_beam_search=self.use_beam_search, - early_stopping=self.early_stopping, prompt_logprobs=prompt_logprobs, skip_special_tokens=self.skip_special_tokens, spaces_between_special_tokens=self.spaces_between_special_tokens, include_stop_str_in_output=self.include_stop_str_in_output, - length_penalty=self.length_penalty, truncate_prompt_tokens=self.truncate_prompt_tokens, output_kind=RequestOutputKind.DELTA if self.stream \ else RequestOutputKind.FINAL_ONLY, diff --git a/vllm/envs.py b/vllm/envs.py index 0f46ac4f61fdf..d15cded416385 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -63,7 +63,6 @@ VLLM_TORCH_PROFILER_DIR: Optional[str] = None VLLM_USE_TRITON_AWQ: bool = False VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False - VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: bool = False VLLM_SKIP_P2P_CHECK: bool = False @@ -198,10 +197,6 @@ def get_default_config_root(): lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in ("true", "1")), - # If set, allowing the use of deprecated beam search implementation - "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH": - lambda: os.environ.get("VLLM_ALLOW_DEPRECATED_BEAM_SEARCH", "0") == "1", - # Internal flag to enable Dynamo graph capture "VLLM_TEST_DYNAMO_GRAPH_CAPTURE": lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")), diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index cfa857b8f9606..0b959da79c3be 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -947,8 +947,6 @@ def get_logprobs( # largest num logprobs in this API. If every logprobs is None, it will be # set to -1. largest_num_logprobs = -1 - # If beam search is enabled. - use_beam_search = False # Select indices to compute logprob from, ranks of token ids, and the top # k token ids from logprobs. @@ -981,8 +979,6 @@ def get_logprobs( largest_num_logprobs = max(largest_num_logprobs, sampling_params.logprobs) - use_beam_search = use_beam_search or sampling_params.use_beam_search - assert len(next_token_ids) == len(query_indices) if len(query_indices) == 0: @@ -995,7 +991,7 @@ def get_logprobs( # If largest_num_logprobs == -1, i.e. no logprobs are requested, we can # skip the whole logprob calculation. - if largest_num_logprobs >= 0 or use_beam_search: + if largest_num_logprobs >= 0: query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) next_token_ids_gpu = torch.tensor(next_token_ids, device=logprobs.device) @@ -1121,13 +1117,12 @@ def _get_sampled_logprob_if_needed( """Compute the sample logprob if needed.""" seq_ids = seq_group.seq_ids num_logprobs = seq_group.sampling_params.logprobs - use_beam_search = seq_group.sampling_params.use_beam_search sampled_logprobs: SampleLogprobs = [] next_token_ids, parent_seq_ids = sample_result if seq_group.do_sample: assert len(next_token_ids) > 0 - if num_logprobs is None and not use_beam_search: + if num_logprobs is None: for next_token_id in next_token_ids: # Use a dummy logprob sampled_logprobs.append({next_token_id: Logprob(inf)}) diff --git a/vllm/outputs.py b/vllm/outputs.py index 44cde6b561d85..4f29226aa5128 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -142,11 +142,7 @@ def from_seq_group(cls, seq_group: SequenceGroup, else: # Get the top-n sequences. 
n = sampling_params.n - if sampling_params.use_beam_search: - sorting_key = lambda seq: seq.get_beam_search_score( - sampling_params.length_penalty) - else: - sorting_key = lambda seq: seq.get_cumulative_logprob() + sorting_key = lambda seq: seq.get_cumulative_logprob() sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) top_n_seqs = sorted_seqs[:n] diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 19832c761d1a4..461b304635e19 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -11,7 +11,6 @@ from pydantic import BaseModel from typing_extensions import Annotated -import vllm.envs as envs from vllm.logger import init_logger from vllm.logits_process import LogitsProcessor, NoBadWordsLogitsProcessor @@ -25,7 +24,6 @@ class SamplingType(IntEnum): GREEDY = 0 RANDOM = 1 RANDOM_SEED = 2 - BEAM = 3 # maybe make msgspec? @@ -126,16 +124,6 @@ class SamplingParams( considered, relative to the probability of the most likely token. Must be in [0, 1]. Set to 0 to disable this. seed: Random seed to use for the generation. - use_beam_search: Whether to use beam search instead of sampling. - length_penalty: Float that penalizes sequences based on their length. - Used in beam search. - early_stopping: Controls the stopping condition for beam search. It - accepts the following values: `True`, where the generation stops as - soon as there are `best_of` complete candidates; `False`, where an - heuristic is applied and the generation stops when is it very - unlikely to find better candidates; `"never"`, where the beam search - procedure only stops when there cannot be better candidates - (canonical beam search algorithm). stop: List of strings that stop the generation when they are generated. The returned output will not contain the stop strings. stop_token_ids: List of tokens that stop the generation when they are @@ -189,9 +177,6 @@ class SamplingParams( top_k: int = -1 min_p: float = 0.0 seed: Optional[int] = None - use_beam_search: bool = False - length_penalty: float = 1.0 - early_stopping: Union[bool, str] = False stop: Optional[Union[str, List[str]]] = None stop_token_ids: Optional[List[int]] = None bad_words: Optional[List[int]] = None @@ -235,9 +220,6 @@ def from_optional( top_k: int = -1, min_p: float = 0.0, seed: Optional[int] = None, - use_beam_search: bool = False, - length_penalty: float = 1.0, - early_stopping: Union[bool, str] = False, stop: Optional[Union[str, List[str]]] = None, stop_token_ids: Optional[List[int]] = None, bad_words: Optional[List[int]] = None, @@ -278,9 +260,6 @@ def from_optional( top_k=top_k, min_p=min_p, seed=seed, - use_beam_search=use_beam_search, - length_penalty=length_penalty, - early_stopping=early_stopping, stop=stop, stop_token_ids=stop_token_ids, bad_words=bad_words, @@ -333,20 +312,13 @@ def __post_init__(self) -> None: self.output_text_buffer_length = max(len(s) for s in self.stop) - 1 self._verify_args() - if self.use_beam_search: - if not envs.VLLM_ALLOW_DEPRECATED_BEAM_SEARCH: - raise ValueError( - "Using beam search as a sampling parameter is deprecated, and will be removed in the future release. Please use the `vllm.LLM.use_beam_search` method for dedicated beam search instead, or set the environment variable `VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1` to suppress this error. For more details, see https://github.com/vllm-project/vllm/issues/8306 ." # noqa - ) - self._verify_beam_search() - else: - self._verify_non_beam_search() - if self.temperature < _SAMPLING_EPS: - # Zero temperature means greedy sampling. 
- self.top_p = 1.0 - self.top_k = -1 - self.min_p = 0.0 - self._verify_greedy_sampling() + + if self.temperature < _SAMPLING_EPS: + # Zero temperature means greedy sampling. + self.top_p = 1.0 + self.top_k = -1 + self.min_p = 0.0 + self._verify_greedy_sampling() # eos_token_id is added to this by the engine self._all_stop_token_ids = set(self.stop_token_ids) @@ -416,31 +388,6 @@ def _verify_args(self) -> None: RequestOutputKind.DELTA): raise ValueError("best_of must equal n to use output_kind=DELTA") - def _verify_beam_search(self) -> None: - if self.best_of == 1: - raise ValueError("best_of must be greater than 1 when using beam " - f"search. Got {self.best_of}.") - if self.temperature > _SAMPLING_EPS: - raise ValueError("temperature must be 0 when using beam search.") - if self.top_p < 1.0 - _SAMPLING_EPS: - raise ValueError("top_p must be 1 when using beam search.") - if self.top_k != -1: - raise ValueError("top_k must be -1 when using beam search.") - if self.early_stopping not in [True, False, "never"]: - raise ValueError( - f"early_stopping must be True, False, or 'never', " - f"got {self.early_stopping}.") - - def _verify_non_beam_search(self) -> None: - if self.early_stopping is not False: - raise ValueError("early_stopping is not effective and must be " - "False when not using beam search.") - if (self.length_penalty < 1.0 - _SAMPLING_EPS - or self.length_penalty > 1.0 + _SAMPLING_EPS): - raise ValueError( - "length_penalty is not effective and must be the " - "default value of 1.0 when not using beam search.") - def _verify_greedy_sampling(self) -> None: assert isinstance(self.best_of, int) if self.best_of > 1: @@ -485,8 +432,6 @@ def update_from_generation_config( @cached_property def sampling_type(self) -> SamplingType: - if self.use_beam_search: - return SamplingType.BEAM if self.temperature < _SAMPLING_EPS: return SamplingType.GREEDY if self.seed is not None: @@ -523,9 +468,6 @@ def __repr__(self) -> str: f"top_k={self.top_k}, " f"min_p={self.min_p}, " f"seed={self.seed}, " - f"use_beam_search={self.use_beam_search}, " - f"length_penalty={self.length_penalty}, " - f"early_stopping={self.early_stopping}, " f"stop={self.stop}, " f"stop_token_ids={self.stop_token_ids}, " f"include_stop_str_in_output={self.include_stop_str_in_output}, " @@ -551,3 +493,4 @@ class BeamSearchParams( max_tokens: int ignore_eos: bool = False temperature: float = 0.0 + length_penalty: float = 1.0 diff --git a/vllm/sequence.py b/vllm/sequence.py index 781bcedde2b52..9116408a001ff 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -577,25 +577,6 @@ def get_output_token_ids(self) -> Tuple[int, ...]: def get_cumulative_logprob(self) -> float: return self.data.cumulative_logprob - def get_beam_search_score(self, - length_penalty: float = 1.0, - seq_len: Optional[int] = None, - eos_token_id: Optional[int] = None) -> float: - """Calculate the beam search score with length penalty. - - Adapted from - - https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938 - """ - if seq_len is None: - seq_len = self.get_len() - # NOTE: HF implementation does not count the EOS token - # towards the length, we align with that here for testing. 
- if (eos_token_id is not None - and self.get_last_token_id() == eos_token_id): - seq_len -= 1 - return self.get_cumulative_logprob() / (seq_len**length_penalty) - def is_finished(self) -> bool: return SequenceStatus.is_finished(self.status) @@ -809,25 +790,18 @@ def set_finished_time(self, time: Optional[float]) -> None: def get_max_num_running_seqs(self) -> int: """The maximum number of sequences running in parallel in the remaining lifetime of the request.""" - if self.sampling_params and self.sampling_params.use_beam_search: - # For beam search, maximally there will always be `best_of` beam - # candidates running in the future. + if self.sampling_params: best_of = self.sampling_params.best_of assert isinstance(best_of, int) - return best_of - else: - if self.sampling_params: - best_of = self.sampling_params.best_of - assert isinstance(best_of, int) - if best_of > self.num_seqs(): - # At prompt stage, the sequence group is not yet filled up - # and only have one sequence running. However, in the - # generation stage, we will have `best_of` sequences - # running. - return best_of - # At sampling stages, return the number of actual sequences - # that are not finished yet. - return self.num_unfinished_seqs() + if best_of > self.num_seqs(): + # At prompt stage, the sequence group is not yet filled up + # and only have one sequence running. However, in the + # generation stage, we will have `best_of` sequences + # running. + return best_of + # At sampling stages, return the number of actual sequences + # that are not finished yet. + return self.num_unfinished_seqs() def get_seqs( self, diff --git a/vllm/utils.py b/vllm/utils.py index e44365fa24990..1b7638c4a12ac 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1361,3 +1361,22 @@ def dec(self, num=1): @property def value(self): return self._value + + +def get_beam_search_score( + tokens: List[int], + cumulative_logprob: float, + eos_token_id: int, + length_penalty: float = 1.0, +) -> float: + """Calculate the beam search score with length penalty. + + Adapted from + + https://github.com/huggingface/transformers/blob/ccb92be23def445f2afdea94c31286f84b89eb5b/src/transformers/generation/beam_search.py#L938 + """ + seq_len = len(tokens) + if tokens[-1] == eos_token_id: + seq_len -= 1 + + return cumulative_logprob / (seq_len**length_penalty) diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index 2472ac25aee44..12e4215038d74 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -453,9 +453,6 @@ def _prepare_sample( f"Best of > {_MAX_NUM_SAMPLES} is not supported by the TPU " "backend.") best_of.append(sampling_params.best_of) - if sampling_params.use_beam_search: - raise NotImplementedError( - "Beam search is not supported by the TPU backend.") if sampling_params.logprobs is not None: raise NotImplementedError( "logprobs is not currently supported by the TPU backend.")
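
---

Usage note (not part of the patch): with `use_beam_search`, `length_penalty`, and `early_stopping` removed from `SamplingParams`, callers migrate to the dedicated beam search entry point, and beams are ranked with the shared `vllm.utils.get_beam_search_score` helper added above. The sketch below is illustrative only: the scoring calls mirror the helper exactly as defined in this patch, while the `LLM.beam_search(...)` call and the `beam_width` field of `BeamSearchParams` are assumptions based on the dedicated API this patch points to (only `max_tokens`, `ignore_eos`, `temperature`, and `length_penalty` appear in the diff).

```python
# Illustrative sketch, assuming the dedicated beam search API described above.
from vllm import LLM
from vllm.sampling_params import BeamSearchParams
from vllm.utils import get_beam_search_score  # helper added by this patch

# Length-penalized ranking exactly as the new helper computes it:
# cumulative_logprob / (seq_len ** length_penalty), with a trailing EOS
# token excluded from seq_len (matching the HF convention).
eos_token_id = 2  # toy value for illustration
print(get_beam_search_score([5, 9, 2], -1.2, eos_token_id))     # -1.2 / 2 = -0.6
print(get_beam_search_score([5, 9, 7, 4], -2.0, eos_token_id))  # -2.0 / 4 = -0.5

# Dedicated beam search call instead of SamplingParams(use_beam_search=True).
# NOTE: beam_width as a BeamSearchParams field is assumed here; it is not
# shown in this diff.
llm = LLM(model="facebook/opt-125m")
outputs = llm.beam_search(
    ["The capital of France is"],
    BeamSearchParams(beam_width=4, max_tokens=32, length_penalty=1.0),
)
```

The design intent visible in the patch is that per-beam scoring now lives in one place (`vllm/utils.py`) and is reused by both `LLM.beam_search` and `AsyncLLMEngine.beam_search` via their `sort_beams_key` closures, rather than being re-implemented inside the scheduler and output processor.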