From 505e60c5c4d648d461227192ff0c14c64898d1c9 Mon Sep 17 00:00:00 2001
From: ApostaC
Date: Tue, 10 Dec 2024 17:37:00 +0000
Subject: [PATCH] Updating the benchmark script with correct usage instructions

Signed-off-by: ApostaC
---
 benchmarks/benchmark_long_document_qa.py | 138 ++++++++++++++++++++---
 1 file changed, 120 insertions(+), 18 deletions(-)

diff --git a/benchmarks/benchmark_long_document_qa.py b/benchmarks/benchmark_long_document_qa.py
index 82e37aaccef96..b5ec21bfd0fba 100644
--- a/benchmarks/benchmark_long_document_qa.py
+++ b/benchmarks/benchmark_long_document_qa.py
@@ -2,35 +2,121 @@
 Benchmark the efficiency of prefix caching.
 
 This script allows you to benchmark the performance of
-a model with and without prefix caching using either fixed prompts
-or prompts sampled from the ShareGPT dataset.
+a model with prefix caching or CPU offloading using fixed prompts.
 
 Fixed example usage:
-    python benchmark_prefix_caching.py \
+    # This command runs vLLM with 50 GB of CPU memory for offloading.
+    # The workload samples 8 different prompts with a default input
+    # length of 20010 tokens, then replicates each prompt 2 times.
+    python benchmark_long_document_qa.py \
         --model meta-llama/Llama-2-7b-chat-hf \
         --enable-prefix-caching \
-        --num-prompts 1 \
-        --repeat-count 100
-
-ShareGPT example usage:
-    # This command samples 20 prompts with input lengths
-    # between 128 and 256 tokens from the ShareGPT dataset,
-    # then replicates each prompt 5 times.
-    python benchmark_prefix_caching.py \
-        --model meta-llama/Llama-2-7b-chat-hf \
-        --dataset-path /path/to/ShareGPT_V3_unfiltered_cleaned_split.json \
-        --enable-prefix-caching \
-        --num-prompts 20 \
-        --repeat-count 5 \
-        --input-length-range 128:256
+        --block-allocator CpuOffloadingBlockAllocator \
+        --num-documents 8 \
+        --repeat-count 2 \
+        --cpu-memory-gb 50
+
+Command-line arguments:
+
+    # Basic arguments
+    --model: The model to use for the benchmark.
+
+    --enable-prefix-caching: Whether to enable prefix caching.
+
+    --block-allocator: The block allocator that vLLM uses.
+        - CpuGpuBlockAllocator: The default block allocator.
+        - CpuOffloadingBlockAllocator: The block allocator that supports
+          CPU offloading.
+
+    --gpu-memory-utilization: GPU memory utilization for vLLM.
+
+    --cpu-memory-gb: The amount of CPU memory (in GB) used by vLLM.
+        NOTE: The CPU memory should be larger than the GPU KV cache size
+        when using CpuOffloadingBlockAllocator.
+
+    # Workload-related arguments
+    --num-documents: The number of documents to sample prompts from.
+
+    --repeat-count: The number of times to repeat each prompt.
+
+    # Other functionality
+    --seed: Random seed for reproducibility.
+
+    --profile-swap-blocks: Profile the swap_blocks function in the custom ops.
 """
 
 import random
 import time
 
+import torch
+
 from vllm import LLM, SamplingParams
 from vllm.utils import FlexibleArgumentParser
 
+"""
+HELPER FUNCTIONS FOR PROFILING
+"""
+execution_times = {}
+
+
+def build_result_dict(start_time, end_time, *args):
+    total_time = end_time - start_time
+    length = -1
+    if len(args) > 1 and isinstance(args[1], torch.Tensor):
+        length = len(args[1])
+
+    return {
+        "start_time": start_time,
+        "total_time": total_time,
+        "swap_len": length
+    }
+
+
+def timing_decorator(func):
+
+    def wrapper(*args, **kwargs):
+        global execution_times
+        torch.cuda.synchronize()
+        start_time = time.time()  # Record the start time
+        result = func(*args, **kwargs)  # Call the wrapped function
+        torch.cuda.synchronize()
+        end_time = time.time()  # Record the end time
+        if func.__name__ not in execution_times:
+            execution_times[func.__name__] = []
+
+        res = build_result_dict(start_time, end_time, *args)
+        execution_times[func.__name__].append(res)
+        return result  # Return the result of the original function
+
+    return wrapper
+
+
+def process_timing_results():
+    global execution_times
+    for key in execution_times:
+        len_to_time = {}
+        len_to_count = {}
+        for item in execution_times[key]:
+            swap_len = item["swap_len"]
+            if swap_len not in len_to_time:
+                len_to_time[swap_len] = 0
+            len_to_time[swap_len] += item["total_time"]
+
+            if swap_len not in len_to_count:
+                len_to_count[swap_len] = 0
+            len_to_count[swap_len] += 1
+
+        for swap_len in len_to_time:
+            total_time = len_to_time[swap_len]
+            count = len_to_count[swap_len]
+            print(f"{key} on {swap_len} pages: "
+                  f"{(count * swap_len) / total_time} pages per second")
+
+
+"""
+MAIN FUNCTIONS FOR BENCHMARKING
+"""
+
 
 def test_long_document_qa(llm=None, sampling_params=None, prompts=None):
@@ -47,6 +133,10 @@ def repeat_prompts(prompts, repeat_count):
 
 
 def main(args):
+    if args.profile_swap_blocks:
+        from vllm.worker.cache_engine import CacheEngine
+        CacheEngine.swap_out = timing_decorator(CacheEngine.swap_out)
+        CacheEngine.swap_in = timing_decorator(CacheEngine.swap_in)
 
     random.seed(args.seed)
 
@@ -72,6 +162,7 @@ def main(args):
               block_allocator=args.block_allocator,
               preemption_mode=preemption_mode,
               swap_space=args.cpu_memory_gb,
+              enable_chunked_prefill=False,
               gpu_memory_utilization=args.gpu_memory_utilization,
               max_model_len=30000)
 
@@ -86,6 +177,8 @@ def main(args):
         sampling_params=sampling_params,
     )
 
+    random.shuffle(prompts)
+
     print("------start generating------")
     test_long_document_qa(
         llm=llm,
@@ -93,6 +186,9 @@ def main(args):
         sampling_params=sampling_params,
    )
 
+    if args.profile_swap_blocks:
+        process_timing_results()
+
 
 if __name__ == "__main__":
     parser = FlexibleArgumentParser(
@@ -136,7 +232,7 @@ def main(args):
                         help='Random seed for reproducibility')
     parser.add_argument('--gpu-memory-utilization',
                         type=float,
-                        default=0.5,
+                        default=0.9,
                         help='GPU memory utilization for vLLM. Should be a '
                         'float point number ranging from 0 to 1. For this '
                         'test please use a small value so that the GPU '
@@ -160,5 +256,11 @@ def main(args):
                         'supports offloading the KV cache to CPU . '
                         'When using CpuOffloadingBlockAllocator, the '
                         'preemption mode must be recompute.')
+
+    parser.add_argument(
+        '--profile-swap-blocks',
+        action='store_true',
+        help='Profile the swap_blocks function in the custom ops')
+
     args = parser.parse_args()
     main(args)