diff --git a/demo_temporary/__init__.py b/demo_temporary/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/demo_temporary/benchmarks/__init__.py b/demo_temporary/benchmarks/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/demo_temporary/benchmarks/benchmark_attention_impl.py b/demo_temporary/benchmarks/benchmark_attention_impl.py new file mode 100644 index 0000000000000..f1da13ed8fb92 --- /dev/null +++ b/demo_temporary/benchmarks/benchmark_attention_impl.py @@ -0,0 +1,102 @@ +import os +import random +import time + + +def benchmark_vllm(args): + random.seed(args.seed) + os.environ["VLLM_ATTENTION_BACKEND"] = args.attention_impl + + import gc + + import torch + + from vllm.wde.encode_only.arg_utils import ( # noqa: E501 + EncodeOnlyEngineArgs as EngineArgs) + from vllm.wde.entrypoints.llm import LLMEngine + + prompt = "if" * args.input_len + requests = [prompt for _ in range(args.num_prompts)] + + engine_args = EngineArgs(model=args.model, + tokenizer=args.tokenizer, + seed=args.seed, + trust_remote_code=args.trust_remote_code, + dtype=args.dtype, + max_model_len=args.max_model_len, + device=args.device, + max_num_seqs=32, + scheduling=args.scheduling) + + engine = LLMEngine.from_engine_args(engine_args) + + for batchsize in args.batchsize: + engine.engine_config.scheduler_config.set_args(max_num_seqs=batchsize) + + start = time.perf_counter() + for request_id, prompt in enumerate(requests): + engine.add_request(str(request_id), prompt) + + n_step = 0 + while engine.has_unfinished_requests(): + engine.step() + n_step += 1 + end = time.perf_counter() + + elapsed_time = end - start + delay = elapsed_time / n_step + + print(f"Batchsize {batchsize}, Throughput: " + f"{len(requests) / elapsed_time:.4f} requests/s, " + f"Delay {delay * 1000:0.2f} ms, n_step {n_step}") + + engine.executor.shutdown_execute_loop() + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == '__main__': + from easydict import EasyDict as edict + args = edict() + + args.input_len = 256 + args.num_prompts = 10000 + + args.model = 'BAAI/bge-m3' + + args.trust_remote_code = False + args.tokenizer = args.model + args.seed = 0 + args.max_model_len = None + args.device = "cuda" + args.batchsize = [1, 2, 4, 8, 16, 32, 64] + args.scheduling = "double_buffer" + + from concurrent.futures import ProcessPoolExecutor + + def run_vllm(args): + with ProcessPoolExecutor(1) as executor: + f = executor.submit(benchmark_vllm, args) + f.result() + + AttentionImpls_fp32 = ["TORCH_SDPA", "XFORMERS", "TORCH_NAIVE"] + AttentionImpls_fp16 = [ + "FLASH_ATTN", "TORCH_SDPA", "XFORMERS", "FLASHINFER", "TORCH_NAIVE" + ] + AttentionImpls_bf16 = [ + "FLASH_ATTN", "TORCH_SDPA", "XFORMERS", "FLASHINFER", "TORCH_NAIVE" + ] + + AttentionImpls = { + "float": AttentionImpls_fp32, + "half": AttentionImpls_fp16, + "bfloat16": AttentionImpls_bf16, + } + + for dtype, attention_impls in AttentionImpls.items(): + print("dtype:", dtype) + for attention_impl in attention_impls: + print("attention_impl:", attention_impl) + args.attention_impl = attention_impl + args.dtype = dtype + run_vllm(args) diff --git a/demo_temporary/benchmarks/benchmark_bge-m3.py b/demo_temporary/benchmarks/benchmark_bge-m3.py new file mode 100644 index 0000000000000..44b84cc7c5bbc --- /dev/null +++ b/demo_temporary/benchmarks/benchmark_bge-m3.py @@ -0,0 +1,119 @@ +import random +import time + + +def benchmark_hf(args): + random.seed(args.seed) + + import torch + from FlagEmbedding import BGEM3FlagModel + + model = 
BGEM3FlagModel(args.model, use_fp16=True) + + prompt = "if" * args.input_len + requests = [prompt for _ in range(args.num_prompts)] + + with torch.no_grad(): + for batchsize in args.batchsize: + start = time.perf_counter() + n_step = 0 + for i in range(0, len(requests), batchsize): + batch = requests[i:i + batchsize] + model.encode(batch, batch_size=batchsize) + n_step += 1 + end = time.perf_counter() + + elapsed_time = end - start + delay = elapsed_time / n_step + + print(f"Batchsize {batchsize}, Throughput: " + f"{len(requests) / elapsed_time:.4f} requests/s, " + f"Delay {delay * 1000:0.2f} ms, n_step {n_step}") + + +def benchmark_vllm(args): + random.seed(args.seed) + + import gc + + import torch + + from vllm.wde.encode_only.arg_utils import ( # noqa: E501 + EncodeOnlyEngineArgs as EngineArgs) + from vllm.wde.entrypoints.llm import LLMEngine + + prompt = "if" * args.input_len + requests = [prompt for _ in range(args.num_prompts)] + + engine_args = EngineArgs(model=args.model, + tokenizer=args.tokenizer, + seed=args.seed, + trust_remote_code=args.trust_remote_code, + dtype=args.dtype, + max_model_len=args.max_model_len, + device=args.device, + max_num_seqs=32, + scheduling=args.scheduling) + + engine = LLMEngine.from_engine_args(engine_args) + + for batchsize in args.batchsize: + engine.engine_config.scheduler_config.set_args(max_num_seqs=batchsize) + + start = time.perf_counter() + for request_id, prompt in enumerate(requests): + engine.add_request(str(request_id), prompt) + + n_step = 0 + while engine.has_unfinished_requests(): + engine.step() + n_step += 1 + end = time.perf_counter() + + elapsed_time = end - start + delay = elapsed_time / n_step + + print(f"Batchsize {batchsize}, Throughput: " + f"{len(requests) / elapsed_time:.4f} requests/s, " + f"Delay {delay * 1000:0.2f} ms, n_step {n_step}") + + engine.executor.shutdown_execute_loop() + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == '__main__': + from easydict import EasyDict as edict + args = edict() + + args.input_len = 256 + args.num_prompts = 10000 + + args.model = 'BAAI/bge-m3' + + args.trust_remote_code = False + args.tokenizer = args.model + args.seed = 0 + args.max_model_len = None + args.dtype = "half" + args.device = "cuda" + args.batchsize = [1, 2, 4, 8, 16, 32, 64] + + from concurrent.futures import ProcessPoolExecutor + + def run_hf(args): + with ProcessPoolExecutor(1) as executor: + f = executor.submit(benchmark_hf, args) + f.result() + + run_hf(args) + + def run_vllm(args): + with ProcessPoolExecutor(1) as executor: + f = executor.submit(benchmark_vllm, args) + f.result() + + for scheduling in ["sync", "async", "double_buffer"]: + print(scheduling) + args.scheduling = scheduling + run_vllm(args) diff --git a/demo_temporary/benchmarks/benchmark_data_parallelism.py b/demo_temporary/benchmarks/benchmark_data_parallelism.py new file mode 100644 index 0000000000000..29c25378cb1b4 --- /dev/null +++ b/demo_temporary/benchmarks/benchmark_data_parallelism.py @@ -0,0 +1,89 @@ +import random +import time + + +def benchmark(args): + random.seed(args.seed) + + import gc + + import torch + + from vllm.wde.encode_only.arg_utils import ( # noqa: E501 + EncodeOnlyEngineArgs as EngineArgs) + from vllm.wde.entrypoints.llm import LLMEngine + + prompt = "if" * args.input_len + requests = [prompt for _ in range(args.num_prompts)] + + engine_args = EngineArgs(model=args.model, + tokenizer=args.tokenizer, + seed=args.seed, + trust_remote_code=args.trust_remote_code, + dtype=args.dtype, + max_model_len=args.max_model_len, + 
device=args.device, + max_num_seqs=32, + scheduling=args.scheduling, + data_parallel_size=args.data_parallel_size) + + engine = LLMEngine.from_engine_args(engine_args) + + for batchsize in args.batchsize: + engine.engine_config.scheduler_config.set_args(max_num_seqs=batchsize) + engine.executor.ensure_start_execute_loop() + + # Because each thread has to load the model separately, + # the loading may not be completed here. + # If it is run only once, the measured data parallel speed will be low. + + for i in range(3): + start = time.perf_counter() + for request_id, prompt in enumerate(requests): + engine.add_request(str(request_id), prompt) + + while engine.has_unfinished_requests(): + engine.step() + end = time.perf_counter() + elapsed_time = end - start + + print(f"Batchsize {batchsize}, Throughput: " + f"{len(requests) / elapsed_time:.4f} requests/s") + + engine.executor.shutdown_execute_loop() + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == '__main__': + from easydict import EasyDict as edict + args = edict() + + args.input_len = 256 + args.num_prompts = 10000 + + args.model = 'BAAI/bge-m3' + + args.trust_remote_code = False + args.tokenizer = args.model + args.seed = 0 + args.max_model_len = None + args.dtype = "half" + args.device = "cuda" + args.batchsize = [1, 2, 4, 8, 16] + args.max_data_parallel_size = 1 + + from concurrent.futures import ProcessPoolExecutor + + def run_vllm(args): + with ProcessPoolExecutor(1) as executor: + f = executor.submit(benchmark, args) + f.result() + + for scheduling in ["async", "double_buffer"]: + for data_parallel_size in range(args.max_data_parallel_size + 1): + print("scheduling:", scheduling, "data_parallel_size", + data_parallel_size) + args.data_parallel_size = data_parallel_size + args.scheduling = scheduling + run_vllm(args) diff --git a/demo_temporary/benchmarks/benchmark_xlm-roberta.py b/demo_temporary/benchmarks/benchmark_xlm-roberta.py new file mode 100644 index 0000000000000..4629070dc5552 --- /dev/null +++ b/demo_temporary/benchmarks/benchmark_xlm-roberta.py @@ -0,0 +1,132 @@ +import random +import time + + +def benchmark_hf(args): + random.seed(args.seed) + + import torch + from transformers import AutoModelForMaskedLM, AutoTokenizer + + from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[args.dtype] + + tokenizer = AutoTokenizer.from_pretrained(args.tokenizer) + model = AutoModelForMaskedLM.from_pretrained(args.model, + torch_dtype=torch_dtype).to( + args.device) + + prompt = "if" * args.input_len + requests = [prompt for _ in range(args.num_prompts)] + + with torch.no_grad(): + for batchsize in args.batchsize: + start = time.perf_counter() + n_step = 0 + for i in range(0, len(requests), batchsize): + batch = requests[i:i + batchsize] + encoded_input = tokenizer(batch, + return_tensors='pt').to(args.device) + model(**encoded_input) + n_step += 1 + end = time.perf_counter() + + elapsed_time = end - start + delay = elapsed_time / n_step + + print(f"Batchsize {batchsize}, Throughput: " + f"{len(requests) / elapsed_time:.4f} requests/s, " + f"Delay {delay * 1000:0.2f} ms, n_step {n_step}") + + +def benchmark_vllm(args): + random.seed(args.seed) + + import gc + + import torch + + from vllm.wde.encode_only.arg_utils import ( # noqa: E501 + EncodeOnlyEngineArgs as EngineArgs) + from vllm.wde.entrypoints.llm import LLMEngine + + prompt = "if" * args.input_len + requests = [prompt for _ in range(args.num_prompts)] + + engine_args = EngineArgs( + model=args.model, + tokenizer=args.tokenizer, + 
quantization=args.quantization, + seed=args.seed, + trust_remote_code=args.trust_remote_code, + dtype=args.dtype, + max_model_len=args.max_model_len, + quantization_param_path=args.quantization_param_path, + device=args.device, + max_num_seqs=32, + scheduling=args.scheduling) + + engine = LLMEngine.from_engine_args(engine_args) + + for batchsize in args.batchsize: + engine.engine_config.scheduler_config.set_args(max_num_seqs=batchsize) + + start = time.perf_counter() + for request_id, prompt in enumerate(requests): + engine.add_request(str(request_id), prompt) + + n_step = 0 + while engine.has_unfinished_requests(): + engine.step() + n_step += 1 + end = time.perf_counter() + + elapsed_time = end - start + delay = elapsed_time / n_step + + print(f"Batchsize {batchsize}, Throughput: " + f"{len(requests) / elapsed_time:.4f} requests/s, " + f"Delay {delay * 1000:0.2f} ms, n_step {n_step}") + + engine.executor.shutdown_execute_loop() + gc.collect() + torch.cuda.empty_cache() + + +if __name__ == '__main__': + from easydict import EasyDict as edict + args = edict() + + args.input_len = 256 + args.num_prompts = 10000 + + args.model = 'FacebookAI/xlm-roberta-base' + #args.model = 'FacebookAI/xlm-roberta-large' + args.trust_remote_code = False + args.tokenizer = args.model + args.seed = 0 + args.quantization = None + args.quantization_param_path = None + args.max_model_len = None + + args.dtype = "half" + args.device = "cuda" + args.batchsize = [1, 2, 4, 8, 16, 32, 64] + from concurrent.futures import ProcessPoolExecutor + + def run_hf(args): + with ProcessPoolExecutor(1) as executor: + f = executor.submit(benchmark_hf, args) + f.result() + + run_hf(args) + + def run_vllm(args): + with ProcessPoolExecutor(1) as executor: + f = executor.submit(benchmark_vllm, args) + f.result() + + for scheduling in ["sync", "async", "double_buffer"]: + print(scheduling) + args.scheduling = scheduling + run_vllm(args) diff --git a/demo_temporary/examples/__init__.py b/demo_temporary/examples/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/demo_temporary/examples/offline_inference_bert.py b/demo_temporary/examples/offline_inference_bert.py new file mode 100644 index 0000000000000..87c7d3bbaff8d --- /dev/null +++ b/demo_temporary/examples/offline_inference_bert.py @@ -0,0 +1,14 @@ +from vllm.wde.entrypoints.llm import LLM + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +llm = LLM(model="google-bert/bert-base-uncased") + +outputs = llm.encode(prompts) +for output in outputs: + print(output.outputs.shape) diff --git a/demo_temporary/examples/offline_inference_bge-m3.py b/demo_temporary/examples/offline_inference_bge-m3.py new file mode 100644 index 0000000000000..f806c607fc85f --- /dev/null +++ b/demo_temporary/examples/offline_inference_bge-m3.py @@ -0,0 +1,15 @@ +from vllm.wde.entrypoints.llm import LLM + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +llm = LLM(model='BAAI/bge-m3') + +outputs = llm.encode(prompts) + +for output in outputs: + print(output.outputs.shape) diff --git a/demo_temporary/examples/offline_inference_bge-reranker-v2-m3.py b/demo_temporary/examples/offline_inference_bge-reranker-v2-m3.py new file mode 100644 index 0000000000000..608c523856945 --- /dev/null +++ b/demo_temporary/examples/offline_inference_bge-reranker-v2-m3.py @@ -0,0 +1,14 @@ +from vllm.wde.entrypoints.llm import LLM 
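+# Offline reranking example: each entry in `pairs` below couples a query with a
+# candidate passage, and llm.reranker() returns one relevance score per pair.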
+ +pairs = [['query', 'passage'], ['what is panda?', 'hi'], + [ + 'what is panda?', 'The giant panda (Ailuropoda melanoleuca), ' + 'sometimes called a panda bear or simply panda, ' + 'is a bear species endemic to China.' + ]] + +llm = LLM(model="BAAI/bge-reranker-v2-m3") + +outputs = llm.reranker(pairs) +for output in outputs: + print(output.score) diff --git a/demo_temporary/examples/offline_inference_bge-v1-5.py b/demo_temporary/examples/offline_inference_bge-v1-5.py new file mode 100644 index 0000000000000..ef5b42f6a2cee --- /dev/null +++ b/demo_temporary/examples/offline_inference_bge-v1-5.py @@ -0,0 +1,15 @@ +from vllm.wde.entrypoints.llm import LLM + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +llm = LLM(model='BAAI/bge-large-zh-v1.5') + +outputs = llm.encode(prompts) + +for output in outputs: + print(output.outputs.shape) diff --git a/demo_temporary/examples/offline_inference_data_parallelism.py b/demo_temporary/examples/offline_inference_data_parallelism.py new file mode 100644 index 0000000000000..1595c45f1306e --- /dev/null +++ b/demo_temporary/examples/offline_inference_data_parallelism.py @@ -0,0 +1,15 @@ +from vllm.wde.entrypoints.llm import LLM + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +llm = LLM(model='BAAI/bge-m3', data_parallel_size=1) + +outputs = llm.encode(prompts) + +for output in outputs: + print(output.outputs.shape) diff --git a/demo_temporary/examples/offline_inference_e5-mistral-7b.py b/demo_temporary/examples/offline_inference_e5-mistral-7b.py new file mode 100644 index 0000000000000..3a1f5e8479062 --- /dev/null +++ b/demo_temporary/examples/offline_inference_e5-mistral-7b.py @@ -0,0 +1,15 @@ +from vllm.wde.entrypoints.llm import LLM + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +llm = LLM(model='intfloat/e5-mistral-7b-instruct') + +outputs = llm.encode(prompts) + +for output in outputs: + print(output.outputs.shape) diff --git a/demo_temporary/examples/offline_inference_gte-Qwen2.py b/demo_temporary/examples/offline_inference_gte-Qwen2.py new file mode 100644 index 0000000000000..379474e3752c2 --- /dev/null +++ b/demo_temporary/examples/offline_inference_gte-Qwen2.py @@ -0,0 +1,15 @@ +from vllm.wde.entrypoints.llm import LLM + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +# You should use it like this +llm = LLM(model="Alibaba-NLP/gte-Qwen2-7B-instruct") + +outputs = llm.encode(prompts) +for output in outputs: + print(output.outputs.shape) diff --git a/demo_temporary/examples/offline_inference_output_last_hidden_states.py b/demo_temporary/examples/offline_inference_output_last_hidden_states.py new file mode 100644 index 0000000000000..3ea8012da5c00 --- /dev/null +++ b/demo_temporary/examples/offline_inference_output_last_hidden_states.py @@ -0,0 +1,14 @@ +from vllm.wde.entrypoints.llm import LLM + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +llm = LLM(model="Qwen/Qwen2-0.5B-Instruct", output_last_hidden_states=True) + +outputs = llm.encode(prompts) +for output in outputs: + print(output.outputs.shape) diff --git a/demo_temporary/examples/offline_inference_snowflake-arctic-embed.py 
b/demo_temporary/examples/offline_inference_snowflake-arctic-embed.py new file mode 100644 index 0000000000000..dd56d80d6bdc9 --- /dev/null +++ b/demo_temporary/examples/offline_inference_snowflake-arctic-embed.py @@ -0,0 +1,15 @@ +from vllm.wde.entrypoints.llm import LLM + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +llm = LLM(model='Snowflake/snowflake-arctic-embed-xs') + +outputs = llm.encode(prompts) + +for output in outputs: + print(output.outputs.shape) diff --git a/demo_temporary/examples/offline_inference_xlm-roberta.py b/demo_temporary/examples/offline_inference_xlm-roberta.py new file mode 100644 index 0000000000000..d39d55f3193c8 --- /dev/null +++ b/demo_temporary/examples/offline_inference_xlm-roberta.py @@ -0,0 +1,14 @@ +from vllm.wde.entrypoints.llm import LLM + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + +llm = LLM(model="FacebookAI/xlm-roberta-base") + +outputs = llm.encode(prompts) +for output in outputs: + print(output.outputs.shape) diff --git a/demo_temporary/profiler/__init__.py b/demo_temporary/profiler/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/demo_temporary/profiler/encode_only_async.py b/demo_temporary/profiler/encode_only_async.py new file mode 100644 index 0000000000000..092741631b7e5 --- /dev/null +++ b/demo_temporary/profiler/encode_only_async.py @@ -0,0 +1,115 @@ +import random +import time + + +def patch(): + from vllm.wde.prefill_only.executor.gpu_executor import GPUAsyncExecutor + + simple_execute_loop = GPUAsyncExecutor.simple_execute_loop + + def p_execute_loop(self, *args, **kwargs): + import torch + with torch.profiler.profile(activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ]) as prof: + simple_execute_loop(self, *args, **kwargs) + + prof.export_chrome_trace("simple_execute_loop.json") + + GPUAsyncExecutor.simple_execute_loop = p_execute_loop + + double_buffer_execute_loop = GPUAsyncExecutor.double_buffer_execute_loop + + def p_execute_loop(self, *args, **kwargs): + import torch + with torch.profiler.profile(activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ]) as prof: + double_buffer_execute_loop(self, *args, **kwargs) + prof.export_chrome_trace("double_buffer_execute_loop.json") + + GPUAsyncExecutor.double_buffer_execute_loop = p_execute_loop + + +def benchmark_vllm(args): + random.seed(args.seed) + patch() + + from vllm.wde.encode_only.arg_utils import ( # noqa: E501 + EncodeOnlyEngineArgs as EngineArgs) + from vllm.wde.entrypoints.llm import LLMEngine + + prompt = "if" * args.input_len + requests = [prompt for _ in range(args.num_prompts)] + + engine_args = EngineArgs( + model=args.model, + tokenizer=args.tokenizer, + quantization=args.quantization, + seed=args.seed, + trust_remote_code=args.trust_remote_code, + dtype=args.dtype, + max_model_len=args.max_model_len, + quantization_param_path=args.quantization_param_path, + device=args.device, + max_num_seqs=32, + scheduling=args.scheduling) + + engine = LLMEngine.from_engine_args(engine_args) + + for batchsize in args.batchsize: + engine.engine_config.scheduler_config.set_args(max_num_seqs=batchsize) + + start = time.perf_counter() + for request_id, prompt in enumerate(requests): + engine.add_request(str(request_id), prompt) + + n_step = 0 + while engine.has_unfinished_requests(): + 
engine.step() + n_step += 1 + end = time.perf_counter() + + elapsed_time = end - start + delay = elapsed_time / n_step + + print(f"Batchsize {batchsize}, Throughput: " + f"{len(requests) / elapsed_time:.4f} requests/s, " + f"Delay {delay * 1000:0.2f} ms, n_step {n_step}") + + engine.executor.shutdown_execute_loop() + + +if __name__ == '__main__': + from easydict import EasyDict as edict + args = edict() + + args.input_len = 256 + args.num_prompts = 100 + + args.model = 'BAAI/bge-m3' + + args.trust_remote_code = False + args.tokenizer = args.model + args.seed = 0 + args.quantization = None + args.quantization_param_path = None + args.max_model_len = None + + args.dtype = "half" + args.device = "cuda" + args.batchsize = [4] + + from concurrent.futures import ProcessPoolExecutor + + def run_vllm(args): + with ProcessPoolExecutor(1) as executor: + f = executor.submit(benchmark_vllm, args) + f.result() + + for scheduling in ["async", "double_buffer"]: + print(scheduling) + args.scheduling = scheduling + run_vllm(args) diff --git a/tests/wde/__init__.py b/tests/wde/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/wde/decode_only/__init__.py b/tests/wde/decode_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/wde/decode_only/attention_impl/__init__.py b/tests/wde/decode_only/attention_impl/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/wde/decode_only/attention_impl/basic_correctness.py b/tests/wde/decode_only/attention_impl/basic_correctness.py new file mode 100644 index 0000000000000..b547b73fc9512 --- /dev/null +++ b/tests/wde/decode_only/attention_impl/basic_correctness.py @@ -0,0 +1,105 @@ +import itertools as it +import random +from typing import List, TypeVar + +import pytest +import torch +import torch.nn as nn +from transformers import BatchEncoding, BatchFeature + +from tests.wde.utils import HfRunner, VllmRunner, compare_embeddings + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature) + + +class Qwen2HfRunner(HfRunner): + + @torch.inference_mode + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: + encoded_input = self.tokenizer(prompts, + padding=True, + truncation=True, + return_tensors="pt").to("cuda") + + outputs = self.model(**encoded_input, output_hidden_states=True) + + last_hidden_states = outputs.hidden_states[-1] + seq_len = encoded_input.attention_mask.sum(axis=1) + + last_hidden_states_list = [] + for e, s in zip(last_hidden_states, seq_len): + last_hidden_states_list.append(e[s - 1]) + return last_hidden_states_list + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRunner + + +@pytest.fixture(scope="session") +def example_prompts(): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 11 + random.shuffle(prompts) + return prompts + + +MODELS = ["Qwen/Qwen2-0.5B-Instruct"] + +AttentionImpls_fp32 = ["TORCH_SDPA", "XFORMERS", "TORCH_NAIVE"] +AttentionImpls_fp16 = [ + "FLASH_ATTN", "TORCH_SDPA", "XFORMERS", "FLASHINFER", "TORCH_NAIVE" +] +AttentionImpls_bf16 = [ + "FLASH_ATTN", "TORCH_SDPA", "XFORMERS", "FLASHINFER", "TORCH_NAIVE" +] + +AttentionImpls = { + "float": AttentionImpls_fp32, + "half": AttentionImpls_fp16, + "bfloat16": AttentionImpls_bf16, +} + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"]) +@pytest.mark.parametrize("max_num_seqs", [2]) 
+@pytest.mark.parametrize("scheduling", ["sync"]) +@torch.inference_mode +def test_basic_correctness_fp16( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_num_seqs: int, + scheduling: str, +) -> None: + attention_impls = AttentionImpls[dtype] + + impl_outputs_list = [] + + for attention_impl in attention_impls: + with vllm_runner(model, + dtype=dtype, + max_num_seqs=max_num_seqs, + scheduling=scheduling, + attention_impl=attention_impl, + output_last_hidden_states=True) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + impl_outputs_list.append((attention_impl, vllm_outputs)) + + tolerance = 1e-2 + for a, b in it.combinations(impl_outputs_list, 2): + similarities = compare_embeddings(a[1], b[1]) + all_similarities = torch.stack(similarities) + + assert torch.all( + (all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"{a[0]} vs {b[0]}, not all values are within {tolerance} of 1.0" diff --git a/tests/wde/decode_only/models/__init__.py b/tests/wde/decode_only/models/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/wde/decode_only/models/test_output_last_hidden_states.py b/tests/wde/decode_only/models/test_output_last_hidden_states.py new file mode 100644 index 0000000000000..b704635b9b8dd --- /dev/null +++ b/tests/wde/decode_only/models/test_output_last_hidden_states.py @@ -0,0 +1,89 @@ +import random +from typing import List, TypeVar + +import pytest +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM, BatchEncoding, BatchFeature + +from tests.wde.utils import HfRunner, VllmRunner, compare_embeddings + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature) + + +class Qwen2HfRunner(HfRunner): + + @torch.inference_mode + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: + encoded_input = self.tokenizer(prompts, + padding=True, + truncation=True, + return_tensors="pt").to("cuda") + + outputs = self.model(**encoded_input, output_hidden_states=True) + + last_hidden_states = outputs.hidden_states[-1] + seq_len = encoded_input.attention_mask.sum(axis=1) + + last_hidden_states_list = [] + for e, s in zip(last_hidden_states, seq_len): + last_hidden_states_list.append(e[s - 1]) + return last_hidden_states_list + + +@pytest.fixture(scope="session") +def hf_runner(): + return Qwen2HfRunner + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRunner + + +@pytest.fixture(scope="session") +def example_prompts(): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 11 + random.shuffle(prompts) + return prompts + + +MODELS = ["Qwen/Qwen2-0.5B-Instruct"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_num_seqs", [2, 3, 5, 7]) +@pytest.mark.parametrize("scheduling", ["sync", "async", "double_buffer"]) +@torch.inference_mode +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_num_seqs: int, + scheduling: str, +) -> None: + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForCausalLM) as hf_model: + hf_outputs = hf_model.encode(example_prompts) + + with vllm_runner(model, + dtype=dtype, + max_num_seqs=max_num_seqs, + scheduling=scheduling, + output_last_hidden_states=True) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + + similarities = compare_embeddings(hf_outputs, 
vllm_outputs) + all_similarities = torch.stack(similarities) + tolerance = 1e-2 + assert torch.all((all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"Not all values are within {tolerance} of 1.0" diff --git a/tests/wde/encode_only/__init__.py b/tests/wde/encode_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/wde/encode_only/attention_impl/__init__.py b/tests/wde/encode_only/attention_impl/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/wde/encode_only/attention_impl/basic_correctness.py b/tests/wde/encode_only/attention_impl/basic_correctness.py new file mode 100644 index 0000000000000..ce1e301194035 --- /dev/null +++ b/tests/wde/encode_only/attention_impl/basic_correctness.py @@ -0,0 +1,86 @@ +import itertools as it +import random +from typing import TypeVar + +import pytest +import torch +import torch.nn as nn +from transformers import BatchEncoding, BatchFeature + +from tests.wde.utils import VllmRunner, compare_embeddings + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature) + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRunner + + +@pytest.fixture(scope="session") +def example_prompts(): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 11 + random.shuffle(prompts) + return prompts + + +MODELS = ["BAAI/bge-m3"] + +AttentionImpls_fp32 = ["TORCH_SDPA", "XFORMERS", "TORCH_NAIVE"] +AttentionImpls_fp16 = [ + "FLASH_ATTN", "TORCH_SDPA", "XFORMERS", "FLASHINFER", "TORCH_NAIVE" +] +AttentionImpls_bf16 = [ + "FLASH_ATTN", "TORCH_SDPA", "XFORMERS", "FLASHINFER", "TORCH_NAIVE" +] + +AttentionImpls = { + "float": AttentionImpls_fp32, + "half": AttentionImpls_fp16, + "bfloat16": AttentionImpls_bf16, +} + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float", "half", "bfloat16"]) +@pytest.mark.parametrize("max_num_seqs", [2]) +@pytest.mark.parametrize("scheduling", ["sync"]) +@torch.inference_mode +def test_basic_correctness_fp16( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_num_seqs: int, + scheduling: str, +) -> None: + attention_impls = AttentionImpls[dtype] + + impl_outputs_list = [] + + for attention_impl in attention_impls: + with vllm_runner( + model, + dtype=dtype, + max_num_seqs=max_num_seqs, + scheduling=scheduling, + attention_impl=attention_impl, + ) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + impl_outputs_list.append((attention_impl, vllm_outputs)) + + tolerance = 1e-2 + for a, b in it.combinations(impl_outputs_list, 2): + similarities = compare_embeddings(a[1], b[1]) + all_similarities = torch.stack(similarities) + + assert torch.all( + (all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"{a[0]} vs {b[0]}, not all values are within {tolerance} of 1.0" diff --git a/tests/wde/encode_only/models/__init__.py b/tests/wde/encode_only/models/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/wde/encode_only/models/test_bert.py b/tests/wde/encode_only/models/test_bert.py new file mode 100644 index 0000000000000..1e3290c2ddbf8 --- /dev/null +++ b/tests/wde/encode_only/models/test_bert.py @@ -0,0 +1,80 @@ +import random +from typing import List, TypeVar + +import pytest +import torch +import torch.nn as nn +from transformers import BatchEncoding, BatchFeature, BertModel + +from 
tests.wde.utils import HfRunner, VllmRunner, compare_embeddings + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature) + + +class BertHfRunner(HfRunner): + + @torch.inference_mode + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: + encoded_input = self.tokenizer(prompts, + padding=True, + truncation=True, + return_tensors="pt").to("cuda") + + outputs = self.model(**encoded_input).pooler_output + return outputs + + +@pytest.fixture(scope="session") +def hf_runner(): + return BertHfRunner + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRunner + + +@pytest.fixture(scope="session") +def example_prompts(): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 11 + random.shuffle(prompts) + return prompts + + +MODELS = ["google-bert/bert-base-uncased"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_num_seqs", [2, 3, 5, 7]) +@pytest.mark.parametrize("scheduling", ["sync", "async", "double_buffer"]) +@torch.inference_mode +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_num_seqs: int, + scheduling: str, +) -> None: + with hf_runner(model, dtype=dtype, auto_cls=BertModel) as hf_model: + hf_outputs = hf_model.encode(example_prompts) + + with vllm_runner(model, + dtype=dtype, + max_num_seqs=max_num_seqs, + scheduling=scheduling) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + + similarities = compare_embeddings(hf_outputs, vllm_outputs) + all_similarities = torch.stack(similarities) + tolerance = 1e-2 + assert torch.all((all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"Not all values are within {tolerance} of 1.0" diff --git a/tests/wde/encode_only/models/test_xlm-roberta.py b/tests/wde/encode_only/models/test_xlm-roberta.py new file mode 100644 index 0000000000000..caaa7263386d2 --- /dev/null +++ b/tests/wde/encode_only/models/test_xlm-roberta.py @@ -0,0 +1,67 @@ +import random +from typing import TypeVar + +import pytest +import torch +import torch.nn as nn +from transformers import BatchEncoding, BatchFeature + +from tests.wde.utils import HfRunner, VllmRunner, compare_embeddings + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature) + + +@pytest.fixture(scope="session") +def hf_runner(): + return HfRunner + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRunner + + +@pytest.fixture(scope="session") +def example_prompts(): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 11 + random.shuffle(prompts) + return prompts + + +MODELS = ["FacebookAI/xlm-roberta-base", "FacebookAI/xlm-roberta-large"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_num_seqs", [2, 3, 5, 7]) +@pytest.mark.parametrize("scheduling", ["sync", "async", "double_buffer"]) +@torch.inference_mode +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_num_seqs: int, + scheduling: str, +) -> None: + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.encode(example_prompts) + + with vllm_runner(model, + dtype=dtype, + max_num_seqs=max_num_seqs, + scheduling=scheduling) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + + 
similarities = compare_embeddings(hf_outputs, vllm_outputs) + all_similarities = torch.stack(similarities) + tolerance = 1e-2 + assert torch.all((all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"Not all values are within {tolerance} of 1.0" diff --git a/tests/wde/reranker/__init__.py b/tests/wde/reranker/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/wde/reranker/models/__init__.py b/tests/wde/reranker/models/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/wde/reranker/models/test_bge-reranker-v2-m3.py b/tests/wde/reranker/models/test_bge-reranker-v2-m3.py new file mode 100644 index 0000000000000..17fe7ab3dd103 --- /dev/null +++ b/tests/wde/reranker/models/test_bge-reranker-v2-m3.py @@ -0,0 +1,113 @@ +import random +from typing import List, TypeVar + +import numpy as np +import pytest +import torch +import torch.nn as nn +from transformers import (AutoModelForSequenceClassification, BatchEncoding, + BatchFeature) + +from tests.wde.utils import HfRunner, VllmRunner, cleanup +from vllm.wde.reranker.schema.engine_io import RerankerInputs + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature) + + +class VllmRerankerRunner(VllmRunner): + + def reranker(self, inputs: RerankerInputs) -> List[float]: + req_outputs = self.model.reranker(inputs) + outputs = [] + for req_output in req_outputs: + score = req_output.score + outputs.append(score) + return outputs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup() + + +class HfRerankerRunner(HfRunner): + + def reranker(self, inputs: RerankerInputs) -> List[float]: + encoded_input = self.tokenizer(inputs, + padding=True, + truncation=True, + return_tensors="pt").to("cuda") + + scores = self.model(**encoded_input).logits.view(-1, ) + return scores.cpu().numpy().tolist() + + +@pytest.fixture(scope="session") +def hf_runner(): + return HfRerankerRunner + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRerankerRunner + + +@pytest.fixture(scope="session") +def example_prompts(): + pairs = [ + ["query", "passage"], + ["what is panda?", "hi"], + [ + "what is panda?", + "The giant panda (Ailuropoda melanoleuca), " + "sometimes called a panda bear or simply panda, " + "is a bear species endemic to China.", + ], + ] * 11 + random.shuffle(pairs) + return pairs + + +MODELS = ["BAAI/bge-reranker-v2-m3"] + + +def sigmoid(x): + return 1 / (1 + np.exp(-x)) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_num_seqs", [2, 3, 5, 7]) +@pytest.mark.parametrize("scheduling", ["sync", "async", "double_buffer"]) +@torch.inference_mode +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_num_seqs: int, + scheduling: str, +) -> None: + with hf_runner(model, + dtype=dtype, + auto_cls=AutoModelForSequenceClassification) as hf_model: + hf_outputs = hf_model.reranker(example_prompts) + + with vllm_runner(model, dtype=dtype, + max_num_seqs=max_num_seqs) as vllm_model: + vllm_outputs = vllm_model.reranker(example_prompts) + + # Without using sigmoid, + # the difference may be greater than 1e-2, resulting in flakey test + hf_outputs = [sigmoid(x) for x in hf_outputs] + vllm_outputs = [sigmoid(x) for x in vllm_outputs] + + all_similarities = np.array(hf_outputs) - np.array(vllm_outputs) + + tolerance = 1e-2 + assert np.all((all_similarities <= 
tolerance) + & (all_similarities >= -tolerance) + ), f"Not all values are within {tolerance} of 1.0" diff --git a/tests/wde/retriever/__init__.py b/tests/wde/retriever/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/wde/retriever/models/__init__.py b/tests/wde/retriever/models/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/tests/wde/retriever/models/test_bge-m3.py b/tests/wde/retriever/models/test_bge-m3.py new file mode 100644 index 0000000000000..5fb4375d7ec7e --- /dev/null +++ b/tests/wde/retriever/models/test_bge-m3.py @@ -0,0 +1,108 @@ +import random +from typing import List, TypeVar + +import numpy as np +import pytest +import torch +import torch.nn as nn +from transformers import BatchEncoding, BatchFeature + +from tests.wde.utils import VllmRunner, cleanup, compare_embeddings_np, is_cpu + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature) + + +class FlagEmbeddingRunner: + + def wrap_device(self, input: _T) -> _T: + if not is_cpu(): + # Check if the input is already on the GPU + if hasattr(input, "device") and input.device.type == "cuda": + return input # Already on GPU, no need to move + return input.to("cuda") + else: + # Check if the input is already on the CPU + if hasattr(input, "device") and input.device.type == "cpu": + return input # Already on CPU, no need to move + return input.to("cpu") + + def __init__( + self, + model_name: str, + dtype: str = "half", + ) -> None: + # depend on FlagEmbedding peft + from FlagEmbedding import BGEM3FlagModel + + self.model_name = model_name + model = BGEM3FlagModel(self.model_name, use_fp16=dtype == "half") + self.model = model + + @torch.inference_mode + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: + output = self.model.encode(prompts) + return output["dense_vecs"] + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup() + + +@pytest.fixture(scope="session") +def hf_runner(): + return FlagEmbeddingRunner + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRunner + + +@pytest.fixture(scope="session") +def example_prompts(): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 11 + random.shuffle(prompts) + return prompts + + +MODELS = ["BAAI/bge-m3"] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_num_seqs", [2, 3, 5, 7]) +@pytest.mark.parametrize("scheduling", ["sync", "async", "double_buffer"]) +@torch.inference_mode +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_num_seqs: int, + scheduling: str, +) -> None: + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.encode(example_prompts) + + with vllm_runner(model, + dtype=dtype, + max_num_seqs=max_num_seqs, + scheduling=scheduling) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + vllm_outputs = [t.cpu().numpy() for t in vllm_outputs] + + similarities = compare_embeddings_np(hf_outputs, vllm_outputs) + all_similarities = np.stack(similarities) + tolerance = 1e-2 + assert np.all((all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"Not all values are within {tolerance} of 1.0" diff --git a/tests/wde/retriever/models/test_bge-v1-5.py b/tests/wde/retriever/models/test_bge-v1-5.py new file mode 100644 index 
0000000000000..6d27eb41e8cca --- /dev/null +++ b/tests/wde/retriever/models/test_bge-v1-5.py @@ -0,0 +1,108 @@ +import random +from typing import List, TypeVar + +import numpy as np +import pytest +import torch +import torch.nn as nn +from transformers import BatchEncoding, BatchFeature + +from tests.wde.utils import VllmRunner, cleanup, compare_embeddings_np, is_cpu + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature) + + +class FlagEmbeddingRunner: + + def wrap_device(self, input: _T) -> _T: + if not is_cpu(): + # Check if the input is already on the GPU + if hasattr(input, "device") and input.device.type == "cuda": + return input # Already on GPU, no need to move + return input.to("cuda") + else: + # Check if the input is already on the CPU + if hasattr(input, "device") and input.device.type == "cpu": + return input # Already on CPU, no need to move + return input.to("cpu") + + def __init__( + self, + model_name: str, + dtype: str = "half", + ) -> None: + # depend on FlagEmbedding peft + from FlagEmbedding import FlagModel + + self.model_name = model_name + model = FlagModel(self.model_name, use_fp16=dtype == "half") + self.model = model + + @torch.inference_mode + def encode(self, prompts: List[str]) -> np.ndarray: + embeddings = self.model.encode(prompts) + return embeddings + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup() + + +@pytest.fixture(scope="session") +def hf_runner(): + return FlagEmbeddingRunner + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRunner + + +@pytest.fixture(scope="session") +def example_prompts(): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 11 + random.shuffle(prompts) + return prompts + + +MODELS = ['BAAI/bge-large-zh-v1.5', 'BAAI/bge-base-en-v1.5'] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_num_seqs", [2, 3, 5, 7]) +@pytest.mark.parametrize("scheduling", ["sync", "async", "double_buffer"]) +@torch.inference_mode +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_num_seqs: int, + scheduling: str, +) -> None: + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.encode(example_prompts) + + with vllm_runner(model, + dtype=dtype, + max_num_seqs=max_num_seqs, + scheduling=scheduling) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + vllm_outputs = [t.cpu().numpy() for t in vllm_outputs] + + similarities = compare_embeddings_np(hf_outputs, vllm_outputs) + all_similarities = np.stack(similarities) + tolerance = 1e-2 + assert np.all((all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"Not all values are within {tolerance} of 1.0" diff --git a/tests/wde/retriever/models/test_e5-mistral-7b-instruct.py b/tests/wde/retriever/models/test_e5-mistral-7b-instruct.py new file mode 100644 index 0000000000000..9084cd1df52c3 --- /dev/null +++ b/tests/wde/retriever/models/test_e5-mistral-7b-instruct.py @@ -0,0 +1,71 @@ +import random +from typing import TypeVar + +import pytest +import torch +import torch.nn as nn +from transformers import AutoModelForCausalLM, BatchEncoding, BatchFeature + +from tests.wde.utils import (SentenceTransformersRunner, VllmRunner, + compare_embeddings) + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature) + + 
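+# The fixtures below pair the SentenceTransformers reference model with the
+# vllm.wde runner; the test asserts that the cosine similarity of every
+# embedding pair stays within a small tolerance of 1.0.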
+@pytest.fixture(scope="session") +def hf_runner(): + return SentenceTransformersRunner + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRunner + + +@pytest.fixture(scope="session") +def example_prompts(): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 11 + random.shuffle(prompts) + return prompts + + +MODELS = [ + "intfloat/e5-mistral-7b-instruct", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_num_seqs", [2, 3, 5, 7]) +@pytest.mark.parametrize("scheduling", ["sync", "async", "double_buffer"]) +@torch.inference_mode +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_num_seqs: int, + scheduling: str, +) -> None: + with hf_runner(model, dtype=dtype, + auto_cls=AutoModelForCausalLM) as hf_model: + hf_outputs = hf_model.encode(example_prompts) + + with vllm_runner(model, + dtype=dtype, + max_num_seqs=max_num_seqs, + scheduling=scheduling) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + + similarities = compare_embeddings(hf_outputs, vllm_outputs) + all_similarities = torch.stack(similarities) + tolerance = 1e-2 + assert torch.all((all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"Not all values are within {tolerance} of 1.0" diff --git a/tests/wde/retriever/models/test_gte-Qwen2.py b/tests/wde/retriever/models/test_gte-Qwen2.py new file mode 100644 index 0000000000000..8104d97fb45dd --- /dev/null +++ b/tests/wde/retriever/models/test_gte-Qwen2.py @@ -0,0 +1,72 @@ +import random +from typing import TypeVar + +import pytest +import torch +import torch.nn as nn +from transformers import AutoModel, BatchEncoding, BatchFeature + +from tests.wde.utils import (SentenceTransformersRunner, VllmRunner, + compare_embeddings) + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature) + + +@pytest.fixture(scope="session") +def hf_runner(): + return SentenceTransformersRunner + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRunner + + +@pytest.fixture(scope="session") +def example_prompts(): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 11 + random.shuffle(prompts) + return prompts + + +MODELS = [ + "Alibaba-NLP/gte-Qwen2-1.5B-instruct", "Alibaba-NLP/gte-Qwen2-7B-instruct" +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_num_seqs", [3]) +@pytest.mark.parametrize("scheduling", ["sync"]) +@torch.inference_mode +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_num_seqs: int, + scheduling: str, +) -> None: + with hf_runner(model, dtype=dtype, auto_cls=AutoModel) as hf_model: + hf_outputs = hf_model.encode(example_prompts) + + with vllm_runner(model, + dtype=dtype, + max_num_seqs=max_num_seqs, + scheduling=scheduling, + switch_to_gte_Qwen2=True) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + + similarities = compare_embeddings(hf_outputs, vllm_outputs) + all_similarities = torch.stack(similarities) + + tolerance = 1e-2 + assert torch.all((all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"Not all values are within {tolerance} of 1.0" diff --git a/tests/wde/retriever/models/test_snowflake-arctic-embed.py 
b/tests/wde/retriever/models/test_snowflake-arctic-embed.py new file mode 100644 index 0000000000000..29b278256356a --- /dev/null +++ b/tests/wde/retriever/models/test_snowflake-arctic-embed.py @@ -0,0 +1,84 @@ +import random +from typing import List, TypeVar + +import pytest +import torch +import torch.nn as nn +from transformers import AutoModel, BatchEncoding, BatchFeature + +from tests.wde.utils import HfRunner, VllmRunner, compare_embeddings + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature) + + +class BertHfRunner(HfRunner): + + @torch.inference_mode + def encode(self, prompts: List[str]) -> torch.Tensor: + encoded_input = self.tokenizer(prompts, + padding=True, + truncation=True, + return_tensors="pt").to("cuda") + + embeddings = self.model(**encoded_input)[0][:, 0] + embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1) + return embeddings + + +@pytest.fixture(scope="session") +def hf_runner(): + return BertHfRunner + + +@pytest.fixture(scope="session") +def vllm_runner(): + return VllmRunner + + +@pytest.fixture(scope="session") +def example_prompts(): + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] * 11 + random.shuffle(prompts) + return prompts + + +MODELS = ['Snowflake/snowflake-arctic-embed-xs'] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_num_seqs", [2]) +@pytest.mark.parametrize("scheduling", ["sync"]) +@torch.inference_mode +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_num_seqs: int, + scheduling: str, +) -> None: + with hf_runner(model, + dtype=dtype, + auto_cls=AutoModel, + model_kwargs={"add_pooling_layer": False}) as hf_model: + hf_outputs = hf_model.encode(example_prompts) + + with vllm_runner(model, + dtype=dtype, + max_num_seqs=max_num_seqs, + scheduling=scheduling) as vllm_model: + vllm_outputs = vllm_model.encode(example_prompts) + + similarities = compare_embeddings(hf_outputs, vllm_outputs) + all_similarities = torch.stack(similarities) + tolerance = 1e-2 + assert torch.all((all_similarities <= 1.0 + tolerance) + & (all_similarities >= 1.0 - tolerance) + ), f"Not all values are within {tolerance} of 1.0" diff --git a/tests/wde/utils.py b/tests/wde/utils.py new file mode 100644 index 0000000000000..8a11424d6c792 --- /dev/null +++ b/tests/wde/utils.py @@ -0,0 +1,164 @@ +import gc +import os +from typing import Any, Dict, List, Optional, TypeVar + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import (AutoModelForCausalLM, AutoTokenizer, BatchEncoding, + BatchFeature) + +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, is_cpu +from vllm.wde.entrypoints.llm import LLM + +_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature) + + +class VllmRunner: + + def __init__(self, + model_name: str, + max_num_seqs: int = 4, + tokenizer_name: Optional[str] = None, + dtype: str = "half", + scheduling: str = "sync", + attention_impl: Optional[str] = None, + **kwargs) -> None: + if attention_impl is not None: + os.environ["VLLM_ATTENTION_BACKEND"] = attention_impl + + self.model = LLM(model=model_name, + tokenizer=tokenizer_name, + trust_remote_code=True, + max_num_seqs=max_num_seqs, + dtype=dtype, + scheduling=scheduling, + **kwargs) + + if attention_impl is not None: + assert (self.model.llm_engine.attn_backend.get_name().lower() == + attention_impl.lower()) + + def 
encode(self, prompts: List[str]) -> List[List[float]]: + req_outputs = self.model.encode(prompts) + outputs = [] + for req_output in req_outputs: + embedding = req_output.outputs + outputs.append(embedding) + return outputs + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup() + + +class HfRunner: + + def wrap_device(self, input: _T) -> _T: + if not is_cpu(): + # Check if the input is already on the GPU + if hasattr(input, "device") and input.device.type == "cuda": + return input # Already on GPU, no need to move + return input.to("cuda") + else: + # Check if the input is already on the CPU + if hasattr(input, "device") and input.device.type == "cpu": + return input # Already on CPU, no need to move + return input.to("cpu") + + def __init__( + self, + model_name: str, + dtype: str = "half", + *, + model_kwargs: Optional[Dict[str, Any]] = None, + auto_cls=AutoModelForCausalLM, + ) -> None: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] + + self.model_name = model_name + + model_kwargs = model_kwargs if model_kwargs is not None else {} + + self.model = self.wrap_device( + auto_cls.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + **model_kwargs, + )) + + self.tokenizer = AutoTokenizer.from_pretrained( + model_name, + torch_dtype=torch_dtype, + trust_remote_code=True, + ) + + @torch.inference_mode + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: + encoded_input = self.tokenizer(prompts, + padding=True, + truncation=True, + return_tensors="pt").to("cuda") + + logits = self.model(**encoded_input).logits + seq_len = encoded_input.attention_mask.sum(axis=1) + + logits_list = [] + for e, s in zip(logits, seq_len): + logits_list.append(e[:s]) + return logits_list + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + del self.model + cleanup() + + +class SentenceTransformersRunner(HfRunner): + + def __init__( + self, + model_name: str, + dtype: str = "half", + *, + model_kwargs: Optional[Dict[str, Any]] = None, + auto_cls=AutoModelForCausalLM, + ) -> None: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] + + self.model_name = model_name + from sentence_transformers import SentenceTransformer + self.model = self.wrap_device( + SentenceTransformer(model_name, + device="cpu", + trust_remote_code=True).to(dtype=torch_dtype)) + + def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: + return self.model.encode(prompts, + convert_to_numpy=False, + normalize_embeddings=False) + + +def compare_embeddings(embeddings1, embeddings2): + similarities = [ + F.cosine_similarity(e1, e2, dim=0) + for e1, e2 in zip(embeddings1, embeddings2) + ] + return similarities + + +def compare_embeddings_np(embeddings1, embeddings2): + similarities = [e1 @ e2.T for e1, e2 in zip(embeddings1, embeddings2)] + return similarities + + +def cleanup(): + gc.collect() + if not is_cpu(): + torch.cuda.empty_cache() diff --git a/vllm/wde/__init__.py b/vllm/wde/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/core/__init__.py b/vllm/wde/core/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/core/arg_utils.py b/vllm/wde/core/arg_utils.py new file mode 100644 index 0000000000000..a10ee22c28cca --- /dev/null +++ b/vllm/wde/core/arg_utils.py @@ -0,0 +1,31 @@ +from dataclasses import dataclass, fields +from typing import List, Optional, Union + +from vllm.logger import init_logger + 
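+# Shared engine arguments for the wde entrypoints: nullable_str() maps empty
+# strings and the literal "None" to None, and EngineArgs.to_dict() converts the
+# dataclass fields into a plain dict.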
+logger = init_logger(__name__) + + +def nullable_str(val: str): + if not val or val == "None": + return None + return val + + +@dataclass +class EngineArgs: + """Arguments for vLLM engine.""" + model: str + served_model_name: Optional[Union[List[str]]] = None + tokenizer: Optional[str] = None + skip_tokenizer_init: bool = False + tokenizer_mode: str = 'auto' + trust_remote_code: bool = False + download_dir: Optional[str] = None + load_format: str = 'auto' + dtype: str = 'auto' + seed: int = 0 + + def to_dict(self): + return dict( + (field.name, getattr(self, field.name)) for field in fields(self)) diff --git a/vllm/wde/core/config.py b/vllm/wde/core/config.py new file mode 100644 index 0000000000000..504067aaaab4e --- /dev/null +++ b/vllm/wde/core/config.py @@ -0,0 +1,761 @@ +import enum +import json +from dataclasses import dataclass, field, fields +from typing import List, Optional, Union + +import torch +from transformers import PretrainedConfig + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.transformers_utils.config import get_config, get_hf_text_config +from vllm.utils import (is_cpu, is_hip, is_neuron, is_openvino, is_xpu, + print_warning_once) + +logger = init_logger(__name__) + +_GB = 1 << 30 + + +class DeviceConfig: + device: Optional[torch.device] + + def __init__(self, device: str = "auto") -> None: + if device == "auto": + # Automated device type detection + if is_neuron(): + self.device_type = "neuron" + elif is_openvino(): + self.device_type = "openvino" + elif is_cpu(): + self.device_type = "cpu" + elif is_xpu(): + self.device_type = "xpu" + else: + # We don't call torch.cuda.is_available() here to + # avoid initializing CUDA before workers are forked + self.device_type = "cuda" + else: + # Device type is assigned explicitly + self.device_type = device + + # Some device types require processing inputs on CPU + if self.device_type in ["neuron", "openvino"]: + self.device = torch.device("cpu") + elif self.device_type in ["tpu"]: + self.device = None + else: + # Set device with device type + self.device = torch.device(self.device_type) + + +class LoadFormat(str, enum.Enum): + AUTO = "auto" + PT = "pt" + SAFETENSORS = "safetensors" + NPCACHE = "npcache" + DUMMY = "dummy" + TENSORIZER = "tensorizer" + SHARDED_STATE = "sharded_state" + BITSANDBYTES = "bitsandbytes" + + +@dataclass +class LoadConfig: + """ + download_dir: Directory to download and load the weights, default to the + default cache directory of huggingface. + load_format: The format of the model weights to load: + "auto" will try to load the weights in the safetensors format and + fall back to the pytorch bin format if safetensors format is + not available. + "pt" will load the weights in the pytorch bin format. + "safetensors" will load the weights in the safetensors format. + "npcache" will load the weights in pytorch format and store + a numpy cache to speed up the loading. + "dummy" will initialize the weights with random values, which is + mainly for profiling. + "tensorizer" will use CoreWeave's tensorizer library for + fast weight loading. + "bitsandbytes" will load nf4 type weights. + ignore_patterns: The list of patterns to ignore when loading the model. + Default to "original/**/*" to avoid repeated loading of llama's + checkpoints. 
+ + """ + + load_format: Union[str, LoadFormat] = LoadFormat.AUTO + download_dir: Optional[str] = None + model_loader_extra_config: Optional[Union[str, dict]] = field( + default_factory=dict) + ignore_patterns: Optional[Union[List[str], str]] = None + + def __post_init__(self): + model_loader_extra_config = self.model_loader_extra_config or {} + if isinstance(model_loader_extra_config, str): + self.model_loader_extra_config = json.loads( + model_loader_extra_config) + self._verify_load_format() + + if self.ignore_patterns is not None and len(self.ignore_patterns) > 0: + logger.info( + "Ignoring the following patterns when downloading weights: %s", + self.ignore_patterns) + else: + self.ignore_patterns = ["original/**/*"] + + def _verify_load_format(self) -> None: + if not isinstance(self.load_format, str): + return + + load_format = self.load_format.lower() + self.load_format = LoadFormat(load_format) + + rocm_not_supported_load_format: List[str] = [] + if is_hip() and load_format in rocm_not_supported_load_format: + rocm_supported_load_format = [ + f for f in LoadFormat.__members__ + if (f not in rocm_not_supported_load_format) + ] + raise ValueError( + f"load format '{load_format}' is not supported in ROCm. " + f"Supported load formats are " + f"{rocm_supported_load_format}") + + +class CacheConfig: + """Configuration for the KV cache. + + Args: + block_size: Size of a cache block in number of tokens. + gpu_memory_utilization: Fraction of GPU memory to use for the + vLLM execution. + swap_space: Size of the CPU swap space per GPU (in GiB). + cache_dtype: Data type for kv cache storage. + num_gpu_blocks_override: Number of GPU blocks to use. This overrides the + profiled num_gpu_blocks if specified. Does nothing if None. + """ + + def __init__( + self, + block_size: int, + gpu_memory_utilization: float, + swap_space: int, + cache_dtype: str, + num_gpu_blocks_override: Optional[int] = None, + sliding_window: Optional[int] = None, + enable_prefix_caching: bool = False, + cpu_offload_gb: float = 0, + ) -> None: + self.block_size = block_size + self.gpu_memory_utilization = gpu_memory_utilization + self.swap_space_bytes = swap_space * _GB + self.num_gpu_blocks_override = num_gpu_blocks_override + self.cache_dtype = cache_dtype + self.sliding_window = sliding_window + self.enable_prefix_caching = enable_prefix_caching + self.cpu_offload_gb = cpu_offload_gb + self._verify_args() + self._verify_cache_dtype() + self._verify_prefix_caching() + + # Will be set after profiling. + self.num_gpu_blocks = None + self.num_cpu_blocks = None + + def metrics_info(self): + # convert cache_config to dict(key: str, value: str) for prometheus + # metrics info + return {key: str(value) for key, value in self.__dict__.items()} + + def _verify_args(self) -> None: + if self.gpu_memory_utilization > 1.0: + raise ValueError( + "GPU memory utilization must be less than 1.0. Got " + f"{self.gpu_memory_utilization}.") + + def _verify_cache_dtype(self) -> None: + if self.cache_dtype == "auto": + pass + elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"): + logger.info( + "Using fp8 data type to store kv cache. It reduces the GPU " + "memory footprint and boosts the performance. 
" + "Meanwhile, it may cause accuracy drop without a proper " + "scaling factor") + else: + raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") + + def _verify_prefix_caching(self) -> None: + if not self.enable_prefix_caching: + return + + if self.sliding_window is not None: + raise NotImplementedError( + "Prefix caching is not supported with sliding window. " + "Run with --disable-sliding-window to use prefix caching.") + if self.cache_dtype == "fp8": + raise NotImplementedError( + "Prefix caching is not supported for fp8 cache_dtype. " + "Run with --kv-cache-dtype auto to use prefix caching.") + + +class ModelConfig: + """Configuration for the model. + + Args: + model: Name or path of the huggingface model to use. + It is also used as the content for `model_name` tag in metrics + output when `served_model_name` is not specified. + tokenizer: Name or path of the huggingface tokenizer to use. + tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if + available, and "slow" will always use the slow tokenizer. + trust_remote_code: Trust remote code (e.g., from HuggingFace) when + downloading the model and tokenizer. + dtype: Data type for model weights and activations. The "auto" option + will use FP16 precision for FP32 and FP16 models, and BF16 precision + for BF16 models. + seed: Random seed for reproducibility. + revision: The specific model version to use. It can be a branch name, + a tag name, or a commit id. If unspecified, will use the default + version. + code_revision: The specific revision to use for the model code on + Hugging Face Hub. It can be a branch name, a tag name, or a + commit id. If unspecified, will use the default version. + rope_scaling: Dictionary containing the scaling configuration for the + RoPE embeddings. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + tokenizer_revision: The specific tokenizer version to use. It can be a + branch name, a tag name, or a commit id. If unspecified, will use + the default version. + max_model_len: Maximum length of a sequence (including prompt and + output). If None, will be derived from the model. + quantization: Quantization method that was used to quantize the model + weights. If None, we assume the model weights are not quantized. + quantization_param_path: Path to JSON file containing scaling factors. + Used to load KV cache scaling factors into the model when KV cache + type is FP8_E4M3 on ROCm (AMD GPU). In the future these will also + be used to load activation and weight scaling factors when the + model dtype is FP8_E4M3 on ROCm. + disable_sliding_window: Whether to disable sliding window. If True, + we will disable the sliding window functionality of the model. + If the model does not support sliding window, this argument is + ignored. + skip_tokenizer_init: If true, skip initialization of tokenizer and + detokenizer. + served_model_name: The model name used in metrics tag `model_name`, + matches the model name exposed via the APIs. If multiple model + names provided, the first name will be used. If not specified, + the model name will be the same as `model`. 
+ """ + + def __init__( + self, + model: str, + tokenizer: str, + tokenizer_mode: str, + trust_remote_code: bool, + dtype: Union[str, torch.dtype], + seed: int, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + disable_sliding_window: bool = False, + skip_tokenizer_init: bool = False, + served_model_name: Optional[Union[str, List[str]]] = None, + ) -> None: + self.model = model + self.tokenizer = tokenizer + self.tokenizer_mode = tokenizer_mode + self.trust_remote_code = trust_remote_code + self.seed = seed + self.revision = revision + self.code_revision = code_revision + self.rope_scaling = rope_scaling + self.rope_theta = rope_theta + # The tokenizer version is consistent with the model version by default. + if tokenizer_revision is None: + self.tokenizer_revision = revision + else: + self.tokenizer_revision = tokenizer_revision + self.quantization = quantization + self.quantization_param_path = quantization_param_path + self.disable_sliding_window = disable_sliding_window + self.skip_tokenizer_init = skip_tokenizer_init + + self.hf_config = get_config(self.model, trust_remote_code, revision, + code_revision, rope_scaling, rope_theta) + self.hf_text_config = get_hf_text_config(self.hf_config) + self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + + if (not self.disable_sliding_window + and self.hf_text_config.model_type == "gemma2" + and self.hf_text_config.sliding_window is not None): + print_warning_once( + "Gemma 2 uses sliding window attention for every odd layer, " + "which is currently not supported by vLLM. Disabling sliding " + "window and capping the max length to the sliding window size " + f"({self.hf_text_config.sliding_window}).") + self.disable_sliding_window = True + + self.max_model_len = _get_and_verify_max_len( + hf_config=self.hf_text_config, + max_model_len=max_model_len, + disable_sliding_window=self.disable_sliding_window, + sliding_window_len=self.get_hf_config_sliding_window()) + self.served_model_name = get_served_model_name(model, + served_model_name) + + if not self.skip_tokenizer_init: + self._verify_tokenizer_mode() + self._verify_quantization() + + def _verify_tokenizer_mode(self) -> None: + tokenizer_mode = self.tokenizer_mode.lower() + if tokenizer_mode not in ["auto", "slow"]: + raise ValueError( + f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " + "either 'auto' or 'slow'.") + self.tokenizer_mode = tokenizer_mode + + def _parse_quant_hf_config(self): + quant_cfg = getattr(self.hf_config, "quantization_config", None) + if quant_cfg is None: + # compressed-tensors uses a "compression_config" key + quant_cfg = getattr(self.hf_config, "compression_config", None) + return quant_cfg + + def _verify_quantization(self) -> None: + supported_quantization = [*QUANTIZATION_METHODS] + rocm_supported_quantization = ["gptq", "squeezellm"] + optimized_quantization_methods = [ + "fp8", "marlin", "gptq_marlin_24", "gptq_marlin", "awq_marlin", + "fbgemm_fp8", "compressed_tensors", "compressed-tensors" + ] + if self.quantization is not None: + self.quantization = self.quantization.lower() + + # Parse quantization method from the HF model config, if available. 
+ quant_cfg = self._parse_quant_hf_config() + + if quant_cfg is not None: + quant_method = quant_cfg.get("quant_method", "").lower() + + # Detect which checkpoint is it + for _, method in QUANTIZATION_METHODS.items(): + quantization_override = method.override_quantization_method( + quant_cfg, self.quantization) + if quantization_override: + quant_method = quantization_override + self.quantization = quantization_override + break + + # Verify quantization configurations. + if self.quantization is None: + self.quantization = quant_method + elif self.quantization != quant_method: + raise ValueError( + "Quantization method specified in the model config " + f"({quant_method}) does not match the quantization " + f"method specified in the `quantization` argument " + f"({self.quantization}).") + + if self.quantization is not None: + if self.quantization not in supported_quantization: + raise ValueError( + f"Unknown quantization method: {self.quantization}. Must " + f"be one of {supported_quantization}.") + if is_hip( + ) and self.quantization not in rocm_supported_quantization: + raise ValueError( + f"{self.quantization} quantization is currently not " + f"supported in ROCm.") + if self.quantization not in optimized_quantization_methods: + logger.warning( + "%s quantization is not fully " + "optimized yet. The speed can be slower than " + "non-quantized models.", self.quantization) + + def get_hf_config_sliding_window(self) -> Optional[int]: + """Get the sliding window size, or None if disabled.""" + + # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in + # addition to sliding window size. We check if that field is present + # and if it's False, return None. + if (hasattr(self.hf_text_config, "use_sliding_window") + and not self.hf_text_config.use_sliding_window): + return None + return getattr(self.hf_text_config, "sliding_window", None) + + def get_sliding_window(self) -> Optional[int]: + """Get the sliding window size, or None if disabled. + """ + # If user disables sliding window, return None. + if self.disable_sliding_window: + return None + # Otherwise get the value from the hf config. + return self.get_hf_config_sliding_window() + + def get_vocab_size(self) -> int: + return self.hf_text_config.vocab_size + + def get_hidden_size(self) -> int: + return self.hf_text_config.hidden_size + + def get_head_size(self) -> int: + # TODO remove hard code + if hasattr(self.hf_text_config, "model_type" + ) and self.hf_text_config.model_type == 'deepseek_v2': + # FlashAttention supports only head_size 32, 64, 128, 256, + # we need to pad head_size 192 to 256 + return 256 + if hasattr(self.hf_text_config, "head_dim"): + return self.hf_text_config.head_dim + # FIXME(woosuk): This may not be true for all models. + return (self.hf_text_config.hidden_size // + self.hf_text_config.num_attention_heads) + + def get_total_num_kv_heads(self) -> int: + """Returns the total number of KV heads.""" + # For GPTBigCode & Falcon: + # NOTE: for falcon, when new_decoder_architecture is True, the + # multi_query flag is ignored and we use n_head_kv for the number of + # KV heads. + falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] + new_decoder_arch_falcon = ( + self.hf_config.model_type in falcon_model_types + and getattr(self.hf_config, "new_decoder_architecture", False)) + if not new_decoder_arch_falcon and getattr(self.hf_text_config, + "multi_query", False): + # Multi-query attention, only one KV head. + # Currently, tensor parallelism is not supported in this case. 
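+            # e.g. GPTBigCode and original-architecture Falcon checkpoints
+            # set multi_query=True, so a single shared KV head is reported.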
+ return 1 + + # For DBRX and MPT + if self.hf_config.model_type == "mpt": + if "kv_n_heads" in self.hf_config.attn_config: + return self.hf_config.attn_config["kv_n_heads"] + return self.hf_config.num_attention_heads + if self.hf_config.model_type == "dbrx": + return getattr(self.hf_config.attn_config, "kv_n_heads", + self.hf_config.num_attention_heads) + + attributes = [ + # For Falcon: + "n_head_kv", + "num_kv_heads", + # For LLaMA-2: + "num_key_value_heads", + # For ChatGLM: + "multi_query_group_num", + ] + for attr in attributes: + num_kv_heads = getattr(self.hf_text_config, attr, None) + if num_kv_heads is not None: + return num_kv_heads + + # For non-grouped-query attention models, the number of KV heads is + # equal to the number of attention heads. + return self.hf_text_config.num_attention_heads + + def get_num_kv_heads(self) -> int: + """Returns the number of KV heads per GPU.""" + total_num_kv_heads = self.get_total_num_kv_heads() + # If tensor parallelism is used, we divide the number of KV heads by + # the tensor parallel size. We will replicate the KV heads in the + # case where the number of KV heads is smaller than the tensor + # parallel size so each GPU has at least one KV head. + return max(1, total_num_kv_heads) + + def get_num_attention_heads(self) -> int: + num_heads = getattr(self.hf_text_config, "num_attention_heads", 0) + return num_heads + + def get_num_layers(self) -> int: + + total_num_hidden_layers = getattr(self.hf_text_config, + "num_hidden_layers", 0) + + return total_num_hidden_layers + + def get_layers_block_type(self) -> List[str]: + num_layers = self.get_num_layers() + # Transformers supports layers_block_type @property + return getattr(self.hf_config, "layers_block_type", + ["attention"] * num_layers) + + def get_num_attention_layers(self) -> int: + return len( + [t for t in self.get_layers_block_type() if t == "attention"]) + + +class SchedulerConfig: + pass + + +class ParallelConfig: + pass + + +_STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.float16, + "float16": torch.float16, + "float": torch.float32, + "float32": torch.float32, + "bfloat16": torch.bfloat16, +} + + +def _get_and_verify_dtype( + config: PretrainedConfig, + dtype: Union[str, torch.dtype], +) -> torch.dtype: + # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct + # because config.torch_dtype can be None. + config_dtype = getattr(config, "torch_dtype", None) + if config_dtype is None: + config_dtype = torch.float32 + + if isinstance(dtype, str): + dtype = dtype.lower() + if dtype == "auto": + if config_dtype == torch.float32: + if config.model_type == "gemma2": + logger.info( + "For Gemma 2, we downcast float32 to bfloat16 instead " + "of float16 by default. Please specify `dtype` if you " + "want to use float16.") + torch_dtype = torch.bfloat16 + else: + # Following the common practice, we use float16 for float32 + # models. + torch_dtype = torch.float16 + else: + torch_dtype = config_dtype + else: + if dtype not in _STR_DTYPE_TO_TORCH_DTYPE: + raise ValueError(f"Unknown dtype: {dtype}") + torch_dtype = _STR_DTYPE_TO_TORCH_DTYPE[dtype] + elif isinstance(dtype, torch.dtype): + torch_dtype = dtype + else: + raise ValueError(f"Unknown dtype: {dtype}") + + # Verify the dtype. + if torch_dtype != config_dtype: + if torch_dtype == torch.float32: + # Upcasting to float32 is allowed. + logger.info("Upcasting %s to %s.", config_dtype, torch_dtype) + pass + elif config_dtype == torch.float32: + # Downcasting from float32 to float16 or bfloat16 is allowed. 
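+            # e.g. serving a float32 checkpoint with dtype="half" or
+            # dtype="bfloat16" takes this branch.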
+ logger.info("Downcasting %s to %s.", config_dtype, torch_dtype) + pass + else: + # Casting between float16 and bfloat16 is allowed with a warning. + logger.warning("Casting %s to %s.", config_dtype, torch_dtype) + + return torch_dtype + + +def get_served_model_name(model: str, + served_model_name: Optional[Union[str, List[str]]]): + """ + If the input is a non-empty list, the first model_name in + `served_model_name` is taken. + If the input is a non-empty string, it is used directly. + For cases where the input is either an empty string or an + empty list, the fallback is to use `self.model`. + """ + if not served_model_name: + return model + if isinstance(served_model_name, list): + return served_model_name[0] + return served_model_name + + +def _get_and_verify_max_len( + hf_config: PretrainedConfig, + max_model_len: Optional[int], + disable_sliding_window: bool, + sliding_window_len: Optional[int], +) -> int: + """Get and verify the model's maximum length.""" + derived_max_model_len = float("inf") + possible_keys = [ + # OPT + "max_position_embeddings", + # GPT-2 + "n_positions", + # MPT + "max_seq_len", + # ChatGLM2 + "seq_length", + # Command-R + "model_max_length", + # Others + "max_sequence_length", + "max_seq_length", + "seq_len", + ] + # Choose the smallest "max_length" from the possible keys. + max_len_key = None + for key in possible_keys: + max_len = getattr(hf_config, key, None) + if max_len is not None: + max_len_key = key if max_len < derived_max_model_len \ + else max_len_key + derived_max_model_len = min(derived_max_model_len, max_len) + + # If sliding window is manually disabled, max_length should be less + # than the sliding window length in the model config. + if disable_sliding_window and sliding_window_len is not None: + max_len_key = "sliding_window" \ + if sliding_window_len < derived_max_model_len else max_len_key + derived_max_model_len = min(derived_max_model_len, sliding_window_len) + + # If none of the keys were found in the config, use a default and + # log a warning. + if derived_max_model_len == float("inf"): + if max_model_len is not None: + # If max_model_len is specified, we use it. + return max_model_len + + default_max_len = 2048 + logger.warning( + "The model's config.json does not contain any of the following " + "keys to determine the original maximum length of the model: " + "%s. Assuming the model's maximum length is %d.", possible_keys, + default_max_len) + derived_max_model_len = default_max_len + + rope_scaling = getattr(hf_config, "rope_scaling", None) + if rope_scaling is not None: + if "type" in rope_scaling: + rope_type = rope_scaling["type"] + elif "rope_type" in rope_scaling: + rope_type = rope_scaling["rope_type"] + else: + raise ValueError( + "rope_scaling must have a 'type' or 'rope_type' key.") + + # The correct one should be "longrope", kept "su" here + # to be backward compatible + if rope_type not in ("su", "longrope", "llama3"): + if disable_sliding_window: + # TODO(robertgshaw): Find a model that supports rope_scaling + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "with rope_scaling. 
Please raise an issue so we can " + "investigate.") + + assert "factor" in rope_scaling + scaling_factor = rope_scaling["factor"] + if rope_type == "yarn": + derived_max_model_len = rope_scaling[ + "original_max_position_embeddings"] + derived_max_model_len *= scaling_factor + + # If the user specified a max length, make sure it is smaller than the + # derived length from the HF model config. + if max_model_len is None: + max_model_len = int(derived_max_model_len) + elif max_model_len > derived_max_model_len: + # Some models might have a separate key for specifying model_max_length + # that will be bigger than derived_max_model_len. We compare user input + # with model_max_length and allow this override when it's smaller. + model_max_length = getattr(hf_config, "model_max_length", None) + if model_max_length is not None and max_model_len <= model_max_length: + if disable_sliding_window: + # TODO(robertgshaw): Find a model that has model_max_length + # with sliding window to see if this case should be allowed. + raise NotImplementedError( + "Disabling sliding window is not supported for models " + "model_max_length in the config. Please raise an issue " + "so we can investigate.") + else: + msg = ( + f"User-specified max_model_len ({max_model_len}) is greater " + f"than the derived max_model_len ({max_len_key}=" + f"{derived_max_model_len} or model_max_length=" + f"{model_max_length} in model's config.json). This may lead " + "to incorrect model outputs or CUDA errors.") + if envs.VLLM_ALLOW_LONG_MAX_MODEL_LEN: + logger.warning( + "%s Make sure the value is correct and within the " + "model context size.", msg) + else: + raise ValueError( + f"{msg} To allow overriding this maximum, set " + "the env var VLLM_ALLOW_LONG_MAX_MODEL_LEN=1") + return int(max_model_len) + + +@dataclass(frozen=True) +class EngineConfig: + model_config: ModelConfig + device_config: DeviceConfig + load_config: LoadConfig + scheduler_config: SchedulerConfig + parallel_config: Optional[ParallelConfig] + + def to_dict(self): + """Return the configs as a dictionary, for use in **kwargs. 
+ """ + return dict( + (field.name, getattr(self, field.name)) for field in fields(self)) + + def log_config(self): + from vllm.version import __version__ as VLLM_VERSION + logger.info( + "Initializing an LLM engine (v%s) with config: " + "model=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "rope_scaling=%r, rope_theta=%r, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, " + "quantization=%s, " + "quantization_param_path=%s, device_config=%s, " + "seed=%d, served_model_name=%s)", + VLLM_VERSION, + self.model_config.model, + self.model_config.tokenizer, + self.model_config.skip_tokenizer_init, + self.model_config.tokenizer_mode, + self.model_config.revision, + self.model_config.rope_scaling, + self.model_config.rope_theta, + self.model_config.tokenizer_revision, + self.model_config.trust_remote_code, + self.model_config.dtype, + self.model_config.max_model_len, + self.load_config.download_dir, + self.load_config.load_format, + self.model_config.quantization, + self.model_config.quantization_param_path, + self.device_config.device, + self.model_config.seed, + self.model_config.served_model_name, + ) + + +def filter_unexpected_fields(cls): + original_init = cls.__init__ + + def new_init(self, *args, **kwargs): + expected_fields = {field.name for field in fields(cls)} + cleaned_kwargs = { + key: value + for key, value in kwargs.items() if key in expected_fields + } + original_init(self, *args, **cleaned_kwargs) + + cls.__init__ = new_init + return cls diff --git a/vllm/wde/core/inputs/__init__.py b/vllm/wde/core/inputs/__init__.py new file mode 100644 index 0000000000000..183fd6023e5ba --- /dev/null +++ b/vllm/wde/core/inputs/__init__.py @@ -0,0 +1,12 @@ +from .registry import InputContext, InputRegistry + +INPUT_REGISTRY = InputRegistry() +""" +The global :class:`~InputRegistry` which is used by :class:`~vllm.LLMEngine` +to dispatch data processing according to the target model. + +See also: + :ref:`input_processing_pipeline` +""" + +__all__ = ["INPUT_REGISTRY", "InputContext", "InputRegistry"] diff --git a/vllm/wde/core/inputs/registry.py b/vllm/wde/core/inputs/registry.py new file mode 100644 index 0000000000000..86c08690aa634 --- /dev/null +++ b/vllm/wde/core/inputs/registry.py @@ -0,0 +1,177 @@ +import functools +from dataclasses import dataclass +from typing import Any, Callable, Dict, Type, TypeVar + +from torch import nn +from transformers import PretrainedConfig + +from vllm.config import ModelConfig +from vllm.logger import init_logger +from vllm.wde.core.schema.engine_io import TextOnlyInputs + +logger = init_logger(__name__) + +C = TypeVar("C", bound=PretrainedConfig) + + +@dataclass(frozen=True) +class InputContext: + """ + Contains information about the model which may be used to + modify the inputs. + """ + + model_config: "ModelConfig" + """The configuration of the model.""" + + def get_hf_config(self, hf_config_type: Type[C]) -> C: + """ + Get the HuggingFace configuration + (:class:`transformers.PretrainedConfig`) of the model, + additionally checking its type. + + Raises: + TypeError: If the model is not of the specified type. + """ + + hf_config = self.model_config.hf_config + if not isinstance(hf_config, hf_config_type): + raise TypeError("Invalid type of HuggingFace config. 
" + f"Expected type: {hf_config_type}, but " + f"found type: {type(hf_config)}") + + return hf_config + + +N = TypeVar("N", bound=Type[nn.Module]) + +InputProcessor = Callable[[InputContext, TextOnlyInputs], TextOnlyInputs] +"""Preprocess the inputs to the model.""" + + +class InputRegistry: + """ + A registry to dispatch data processing + according to the target model. + """ + + def __init__(self) -> None: + self._dummy_factories_by_model_type: Dict[Type[nn.Module], Any] = {} + self._input_processors_by_model_type: Dict[Type[nn.Module], + InputProcessor] = {} + + def _default_dummy_data_factory( + self, + ctx: InputContext, + seq_len: int, + ): + """ + The default dummy data factory represents the longest possible text + that can be inputted to the model. + + Note: + :data:`InputProcessor` is not applied to the dummy data. + """ + # Avoid circular import + from vllm.sequence import SequenceData + + dummy_seq_data = SequenceData([0] * seq_len) + dummy_multi_modal_data = None + + return dummy_seq_data, dummy_multi_modal_data + + def register_dummy_data(self, factory): + """ + Register a dummy data factory to a model class. + + During memory profiling, the provided function is invoked to create + dummy data to be inputted into the model. The resulting memory usage + should be an upper bound of what the model would use at inference time. + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._dummy_factories_by_model_type: + logger.warning( + "Model class %s already has dummy data " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._dummy_factories_by_model_type[model_cls] = factory + + return model_cls + + return wrapper + + def dummy_data_for_profiling(self, model_config, seq_len: int): + """ + Create dummy data for profiling the memory usage of a model. + + The model is identified by ``model_config``. + + See also: + :ref:`enabling_multimodal_inputs` + """ + # Avoid circular import + from vllm.wde.core.loader.utils import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + dummy_factory = self._dummy_factories_by_model_type \ + .get(model_cls, self._default_dummy_data_factory) + + return dummy_factory(InputContext(model_config), seq_len) + + def _default_input_processor(self, ctx: InputContext, + inputs: TextOnlyInputs) -> TextOnlyInputs: + """The default input processor is a no-op.""" + return inputs + + def register_input_processor(self, processor: InputProcessor): + """ + Register an input processor to a model class. + + The provided function is invoked on each input to the model. This + happens before :meth:`~vllm.multimodal.MultiModalRegistry.map_input`. + + See also: + :ref:`input_processing_pipeline` + """ + + def wrapper(model_cls: N) -> N: + if model_cls in self._input_processors_by_model_type: + logger.warning( + "Model class %s already has input processor " + "registered to %s. It is overwritten by the new one.", + model_cls, self) + + self._input_processors_by_model_type[model_cls] = processor + + return model_cls + + return wrapper + + def process_input(self, model_config, + inputs: TextOnlyInputs) -> TextOnlyInputs: + """ + Apply an input processor to an instance of model inputs. + + The model is identified by ``model_config``. 
+ + See also: + :ref:`input_processing_pipeline` + """ + # Avoid circular import + from vllm.wde.core.loader.utils import get_model_architecture + + model_cls, _ = get_model_architecture(model_config) + + processor = self._input_processors_by_model_type \ + .get(model_cls, self._default_input_processor) + + return processor(InputContext(model_config), inputs) + + def create_input_processor(self, model_config): + """ + Create an input processor (see :meth:`process_input`) for a + specific model. + """ + return functools.partial(self.process_input, model_config) diff --git a/vllm/wde/core/inputs/tokenizer.py b/vllm/wde/core/inputs/tokenizer.py new file mode 100644 index 0000000000000..2fce62106879f --- /dev/null +++ b/vllm/wde/core/inputs/tokenizer.py @@ -0,0 +1,485 @@ +from typing import Dict, List, Optional, Tuple, Union + +from transformers import (AutoTokenizer, PreTrainedTokenizer, + PreTrainedTokenizerFast) + +from vllm.envs import VLLM_USE_MODELSCOPE +from vllm.logger import init_logger +from vllm.sequence import Logprob, SamplingParams, Sequence, SequenceGroup + +logger = init_logger(__name__) + +INVALID_TOKEN_ID = -1 + + +class Tokenizer(object): + + def __init__(self, tokenizer_name: str, **kwargs): + self.tokenizer_name = tokenizer_name + self.tokenizer_kwargs = kwargs + + # layzer_load + self._tokenizer = None + + @property + def tokenizer(self): + if self._tokenizer is None: + self._tokenizer = get_tokenizer(tokenizer_name=self.tokenizer_name, + **self.tokenizer_kwargs) + + return self._tokenizer + + @classmethod + def from_engine(cls, engine): + init_kwargs = dict( + tokenizer_name=engine.engine_config.model_config.tokenizer, + tokenizer_mode=engine.engine_config.model_config.tokenizer_mode, + trust_remote_code=engine.engine_config.model_config. + trust_remote_code, + revision=engine.engine_config.model_config.tokenizer_revision) + + return cls(**init_kwargs) + + def __call__(self, *args, **kwargs): + return self.tokenizer(*args, **kwargs) + + def apply_chat_template(self, *args, **kwargs): + return self.tokenizer.apply_chat_template(*args, **kwargs) + + def encode(self, *args, **kwargs): + return self.tokenizer.encode(*args, **kwargs) + + @property + def eos_token_id(self): + return self.tokenizer.eos_token_id + + def decode_prompt_logprobs_inplace(self, seq_group: SequenceGroup, + prompt_logprobs: List[Optional[Dict[ + int, Logprob]]], + position_offset: int) -> None: + """Decodes the logprobs for the prompt of a sequence group. + + Args: + seq_group: The sequence group to decode. + prompt_logprobs: The logprobs to decode. + position_offset: Offset of the first index of the logprobs + relative to the start of the sequence (for chunked prefill). + + Returns: + The prompt logprobs with the decoded tokens. + """ + prms = seq_group.sampling_params + assert prms is not None + + # We can pick any sequence for the prompt. + seq = seq_group.get_seqs()[0] + # Only prompt, without the generated token. + all_token_ids = seq.get_token_ids() + prompt_token_ids = all_token_ids[:-1] + tokenizer = self.get_tokenizer_for_seq(seq) + prefix_offset = 0 + read_offset = 0 + next_iter_prefix_offset = 0 + next_iter_read_offset = 0 + next_iter_tokens: List[str] = [] + prev_tokens = None + + for token_position_in_logprob, prompt_logprobs_for_token in enumerate( + prompt_logprobs): + + # Absolute token position equals the index in the logprobs + # list plus the offset of the entire logprobs list relative + # to the start of the sequence. 
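+            # With chunked prefill, position_offset is the number of prompt
+            # tokens already processed in earlier chunks, so token_position
+            # always indexes into the full prompt.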
+ token_position = token_position_in_logprob + position_offset + if not prompt_logprobs_for_token: + continue + for token_id, sample_logprob in prompt_logprobs_for_token.items(): + if (sample_logprob.decoded_token is None + and token_id != INVALID_TOKEN_ID): + prompt_token_ids_with_token = ( + prompt_token_ids[:token_position] + [token_id]) + (new_tokens, new_text, new_prefix_offset, + new_read_offset) = detokenize_incrementally( + tokenizer=tokenizer, + all_input_ids=prompt_token_ids_with_token, + prev_tokens=prev_tokens, + prefix_offset=prefix_offset, + read_offset=read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms. + spaces_between_special_tokens, + ) + + sample_logprob.decoded_token = new_text + + # Use the offsets & prev tokens corresponding to + # real tokens to ensure detokenization is consistent + # actual with prompt. + if token_id == all_token_ids[token_position]: + next_iter_prefix_offset = new_prefix_offset + next_iter_read_offset = new_read_offset + next_iter_tokens = new_tokens + + # Advance to the next token position. + prefix_offset = next_iter_prefix_offset + read_offset = next_iter_read_offset + if prev_tokens is None: + prev_tokens = next_iter_tokens + else: + prev_tokens.extend(next_iter_tokens) + + def decode_sequence_inplace(self, seq: Sequence, + prms: SamplingParams) -> int: + """Decodes the new token for a sequence. In-place operation. + + Args: + seq: The sequence to decode. + prms: The sampling parameters used to generate the sequence. + + Returns: + The number of characters added to the output text. + """ + all_input_ids = seq.get_token_ids() + token_id_generated_this_iteration = all_input_ids[-1] + tokenizer = self.get_tokenizer_for_seq(seq) + + # Convert prompt token IDs to tokens if necessary. + # Do it here so that we don't have to repeat this + # computation for each logprob. + if seq.tokens is None: + (seq.tokens, seq.prefix_offset, + seq.read_offset) = convert_prompt_ids_to_tokens( + tokenizer=tokenizer, + prompt_ids=all_input_ids[:-1], + skip_special_tokens=prms.skip_special_tokens, + ) + + (new_tokens, new_decoded_token_text, prefix_offset, + read_offset) = detokenize_incrementally( + tokenizer=tokenizer, + all_input_ids=all_input_ids, + prev_tokens=seq.tokens, + prefix_offset=seq.prefix_offset, + read_offset=seq.read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms.spaces_between_special_tokens, + ) + + # Decode logprobs + logprobs = seq.output_logprobs[-1] + if logprobs: + previous_tokens = all_input_ids[:-1] + for token_id, sample_logprob in logprobs.items(): + # If the token was generated this iteration, + # use the provided text. + if token_id == token_id_generated_this_iteration: + sample_logprob.decoded_token = new_decoded_token_text + continue + + if (sample_logprob.decoded_token is None + and token_id != INVALID_TOKEN_ID): + all_input_ids_with_logprob = previous_tokens + [token_id] + (_, new_text, _, _) = detokenize_incrementally( + tokenizer=tokenizer, + all_input_ids=all_input_ids_with_logprob, + prev_tokens=seq.tokens, + prefix_offset=seq.prefix_offset, + read_offset=seq.read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms. 
+ spaces_between_special_tokens, + ) + sample_logprob.decoded_token = new_text + + seq.tokens.extend(new_tokens) + seq.prefix_offset = prefix_offset + seq.read_offset = read_offset + seq.output_text += new_decoded_token_text + + return len(new_decoded_token_text) + + +def get_cached_tokenizer( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast] +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """Get tokenizer with cached properties. + + This will patch the tokenizer object in place. + + By default, transformers will recompute multiple tokenizer properties + each time they are called, leading to a significant slowdown. This + function caches these properties for faster access.""" + + tokenizer_all_special_ids = set(tokenizer.all_special_ids) + tokenizer_all_special_tokens_extended = ( + tokenizer.all_special_tokens_extended) + tokenizer_all_special_tokens = set(tokenizer.all_special_tokens) + tokenizer_len = len(tokenizer) + + class CachedTokenizer(tokenizer.__class__): # type: ignore + + @property + def all_special_ids(self): + return tokenizer_all_special_ids + + @property + def all_special_tokens(self): + return tokenizer_all_special_tokens + + @property + def all_special_tokens_extended(self): + return tokenizer_all_special_tokens_extended + + def __len__(self): + return tokenizer_len + + CachedTokenizer.__name__ = f"Cached{tokenizer.__class__.__name__}" + + tokenizer.__class__ = CachedTokenizer + return tokenizer + + +def get_tokenizer( + tokenizer_name: str, + *args, + tokenizer_mode: str = "auto", + trust_remote_code: bool = False, + revision: Optional[str] = None, + download_dir: Optional[str] = None, + **kwargs, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """Gets a tokenizer for the given model name via HuggingFace or ModelScope. + """ + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + import os + + import huggingface_hub + from modelscope.hub.snapshot_download import snapshot_download + + # Only set the tokenizer here, model will be downloaded on the workers. + if not os.path.exists(tokenizer_name): + tokenizer_path = snapshot_download( + model_id=tokenizer_name, + cache_dir=download_dir, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + # Ignore weights - we only need the tokenizer. + ignore_file_pattern=[".*.pt", ".*.safetensors", ".*.bin"]) + tokenizer_name = tokenizer_path + + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): + raise ValueError( + "Cannot use the fast tokenizer in slow tokenizer mode.") + kwargs["use_fast"] = False + + if "truncation_side" not in kwargs: + kwargs["truncation_side"] = "left" + + try: + tokenizer = AutoTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs) + except ValueError as e: + # If the error pertains to the tokenizer class not existing or not + # currently being imported, suggest using the --trust-remote-code flag. + if (not trust_remote_code and + ("does not exist or is not currently imported." in str(e) + or "requires you to execute the tokenizer file" in str(e))): + err_msg = ( + "Failed to load the tokenizer. 
If the tokenizer is a custom " + "tokenizer not yet available in the HuggingFace transformers " + "library, consider setting `trust_remote_code=True` in LLM " + "or using the `--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + except AttributeError as e: + if "BaichuanTokenizer" in str(e): + # This is for the error "'BaichuanTokenizer' object has no + # attribute 'sp_model'". + from vllm.transformers_utils.tokenizers import BaichuanTokenizer + tokenizer = BaichuanTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs) + else: + raise e + + if not isinstance(tokenizer, PreTrainedTokenizerFast): + logger.warning( + "Using a slow tokenizer. This might cause a significant " + "slowdown. Consider using a fast tokenizer instead.") + return get_cached_tokenizer(tokenizer) + + +def _replace_none_with_empty(tokens: List[Optional[str]]): + for i, token in enumerate(tokens): + if token is None: + tokens[i] = "" + + +def _convert_tokens_to_string_with_added_encoders( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + output_tokens: List[str], + skip_special_tokens: bool, + spaces_between_special_tokens: bool, +) -> str: + # Adapted from + # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/tokenization_utils.py#L921 + # NOTE(woosuk): The following code is slow because it runs a for loop over + # the output_tokens. In Python, running a for loop over a list can be slow + # even when the loop body is very simple. + sub_texts: List[str] = [] + current_sub_text: List[str] = [] + all_special_tokens = set(tokenizer.all_special_tokens) + for token in output_tokens: + if skip_special_tokens and token in all_special_tokens: + continue + if token in tokenizer.get_added_vocab(): + if current_sub_text: + sub_text = tokenizer.convert_tokens_to_string(current_sub_text) + sub_texts.append(sub_text) + current_sub_text = [] + sub_texts.append(token) + else: + current_sub_text.append(token) + if current_sub_text: + sub_text = tokenizer.convert_tokens_to_string(current_sub_text) + sub_texts.append(sub_text) + if spaces_between_special_tokens: + return " ".join(sub_texts) + else: + return "".join(sub_texts) + + +# 5 is an arbitrary value that should work for all +# tokenizers (bigger = more conservative). +INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET = 5 + + +def convert_prompt_ids_to_tokens( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + prompt_ids: List[int], + skip_special_tokens: bool = False, +) -> Tuple[List[str], int, int]: + """Converts the prompt ids to tokens and returns the tokens and offsets + for incremental detokenization. + + Note that not all tokens are converted to strings. Only the tokens that + are necessary for incremental detokenization are converted to strings. + """ + # We do not need to convert the whole prompt to tokens. + # Offset a little more in case we have special tokens. 
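+    # With the default offset of 5, at most the last 7 prompt ids are
+    # converted; read_offset then points just past them and prefix_offset
+    # sits 5 tokens back, leaving slack in case trailing ids are special
+    # tokens.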
+ new_tokens = tokenizer.convert_ids_to_tokens( + prompt_ids[-INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET - 2:], + skip_special_tokens=skip_special_tokens) + read_offset = len(new_tokens) + prefix_offset = max( + read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) + # This is required to guard against out-of-vocab prompt token ids + _replace_none_with_empty(new_tokens) + return new_tokens, prefix_offset, read_offset + + +# Based on +# https://github.com/huggingface/text-generation-inference/blob/v0.9.4/server/text_generation_server/models/model.py#L62C9-L62C15 +# under Apache 2.0 license +def detokenize_incrementally( + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + all_input_ids: List[int], + prev_tokens: Optional[List[str]], + prefix_offset: int, + read_offset: int, + skip_special_tokens: bool = False, + spaces_between_special_tokens: bool = True, +) -> Tuple[List[str], str, int, int]: + """Detokenizes the input ids incrementally and returns the new tokens + and the new text. + + If `prev_tokens` is None, this function will convert the input ids to + tokens and return the tokens and the new text. Otherwise, it will return the + new tokens and the new text. + + This function will also return the new prefix offset and the new read + offset to be used in the next iteration. + + The offsets are necessary to defeat cleanup algorithms in the decode which + decide to add a space or not depending on the surrounding ids. + + Args: + tokenizer: The tokenizer to use. + all_input_ids: The input ids. The last id is the new token id. + prev_tokens: The previous tokens. If None, this function will convert + the input ids to tokens and return the tokens and the new text. + prefix_offset: The prefix offset. + read_offset: The read offset. + skip_special_tokens: Whether to skip special tokens. + spaces_between_special_tokens: Whether to add spaces between special + tokens. + """ + new_token_id = all_input_ids[-1] + # This is the first iteration for this sequence + is_first_iter = prev_tokens is None + if is_first_iter: + (prev_tokens, prefix_offset, + read_offset) = convert_prompt_ids_to_tokens( + tokenizer, + all_input_ids[:-1], + skip_special_tokens=skip_special_tokens) + assert prev_tokens is not None + + # If the new token id is out of bounds, return an empty string. + if new_token_id >= len(tokenizer): + new_tokens = [""] + else: + # Put new_token_id in a list so skip_special_tokens is respected + new_tokens = tokenizer.convert_ids_to_tokens( + [new_token_id], skip_special_tokens=skip_special_tokens) + if isinstance(new_tokens, str): + new_tokens = [new_tokens] + output_tokens = prev_tokens + new_tokens + + # If this is the first iteration, return all tokens. + if is_first_iter: + new_tokens = output_tokens + + # The prefix text is necessary only to defeat cleanup algorithms in + # the decode which decide to add a space or not depending on the + # surrounding ids. 
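+    # Worked example with a SentencePiece-style tokenizer: if the window
+    # output_tokens[prefix_offset:read_offset] is ["▁It", "▁is"] and the new
+    # token is "▁raining", decoding "▁raining" alone would give "raining"
+    # (the leading space is lost); decoding "It is" and "It is raining" and
+    # appending only the difference preserves the space: " raining".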
+ if tokenizer.is_fast or not tokenizer.get_added_vocab(): + prefix_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:read_offset]) + new_text = tokenizer.convert_tokens_to_string( + output_tokens[prefix_offset:]) + else: + prefix_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:read_offset], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + new_text = _convert_tokens_to_string_with_added_encoders( + tokenizer, + output_tokens[prefix_offset:], + skip_special_tokens=skip_special_tokens, + spaces_between_special_tokens=spaces_between_special_tokens, + ) + + if len(new_text) <= len(prefix_text) or new_text.endswith("�"): + # utf-8 char at the end means it's a potential unfinished byte sequence + # from byte fallback tokenization. + # If it's in the middle, it's probably a real invalid id generated + # by the model + return new_tokens, "", prefix_offset, read_offset + + new_text = new_text[len(prefix_text):] + return new_tokens, new_text, read_offset, len(output_tokens) diff --git a/vllm/wde/core/layers/__init__.py b/vllm/wde/core/layers/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/core/layers/attention/__init__.py b/vllm/wde/core/layers/attention/__init__.py new file mode 100644 index 0000000000000..79bcc8ff6e9b5 --- /dev/null +++ b/vllm/wde/core/layers/attention/__init__.py @@ -0,0 +1,8 @@ +from vllm.wde.core.layers.attention.abstract import (AttentionBackend, + AttentionMetadata, + AttentionType) +from vllm.wde.core.layers.attention.layer import Attention + +__all__ = [ + "Attention", "AttentionMetadata", "AttentionBackend", "AttentionType" +] diff --git a/vllm/wde/core/layers/attention/abstract.py b/vllm/wde/core/layers/attention/abstract.py new file mode 100644 index 0000000000000..0e7ce7b71e4da --- /dev/null +++ b/vllm/wde/core/layers/attention/abstract.py @@ -0,0 +1,124 @@ +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum, auto +from typing import Any, Dict, Generic, List, Optional, Type, TypeVar + +import torch + + +class AttentionType(Enum): + DECODER = auto() # Decoder attention between previous layer Q/K/V + ENCODER = auto() # Encoder attention between previous layer Q/K/V + ENCODER_DECODER = auto() # Attention between dec. Q and enc. K/V + + @staticmethod + def attn_type_name_to_enum(attn_type: str) -> "AttentionType": + assert attn_type is not None + + attn_type_members = AttentionType.__members__ + if attn_type not in attn_type_members: + raise ValueError( + f"Invalid attn_type '{attn_type}'. 
" + f"Available backends: {', '.join(attn_type_members)} " + "(case-sensitive).") + + return AttentionType[attn_type] + + +class AttentionBackend(ABC): + """Abstract class for attention backends.""" + + def __init__(self, attn_type: AttentionType): + self._attn_type = attn_type + + @property + def attn_type(self) -> AttentionType: + return self._attn_type + + @staticmethod + @abstractmethod + def get_name() -> str: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_impl_cls() -> Type["AttentionImpl"]: + raise NotImplementedError + + @staticmethod + @abstractmethod + def get_metadata_cls() -> Type["AttentionMetadata"]: + raise NotImplementedError + + @classmethod + def make_metadata(cls, *args, **kwargs) -> "AttentionMetadata": + return cls.get_metadata_cls()(*args, **kwargs) + + @staticmethod + @abstractmethod + def get_builder_cls() -> Type["AttentionMetadataBuilder"]: + raise NotImplementedError + + @classmethod + def make_metadata_builder(cls, *args, + **kwargs) -> "AttentionMetadataBuilder": + return cls.get_builder_cls()(*args, **kwargs) + + +@dataclass +class AttentionMetadata: + pass + + def to(self, device, non_blocking=False): + for k, v in self.__dict__.items(): + if isinstance(v, torch.Tensor): + self.__dict__[k] = v.to(device, non_blocking=non_blocking) + + return self + + +T = TypeVar("T", bound=AttentionMetadata) + + +class AttentionMetadataBuilder(ABC, Generic[T]): + """Abstract class for attention metadata builders.""" + + @abstractmethod + def __init__(self) -> None: + raise NotImplementedError + + @abstractmethod + def __call__(self, *args, **kwargs) -> T: + raise NotImplementedError + + +class AttentionImpl(ABC, Generic[T]): + + @abstractmethod + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + raise NotImplementedError + + @abstractmethod + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + kv_cache: Optional[torch.Tensor] = None, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.ENCODER, + ) -> torch.Tensor: + raise NotImplementedError diff --git a/vllm/wde/core/layers/attention/layer.py b/vllm/wde/core/layers/attention/layer.py new file mode 100644 index 0000000000000..9d9e55a07aed6 --- /dev/null +++ b/vllm/wde/core/layers/attention/layer.py @@ -0,0 +1,101 @@ +"""Attention layer.""" +from typing import Any, Dict, List, Optional + +import torch +import torch.nn as nn + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.wde.core.layers.attention.abstract import AttentionBackend + + +class Attention(nn.Module): + """Attention layer. + + This class takes query, key, and value tensors as input. The input tensors + can either contain prompt tokens or generation tokens. + The class does the following: + + 1. Store the input key and value tensors in the KV cache. + 2. Perform (multi-head/multi-query/grouped-query) attention. + 3. Return the output tensor. 
+ """ + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + attn_backend: AttentionBackend, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + prefix: str = "", + ) -> None: + super().__init__() + if cache_config is not None: + kv_cache_dtype = cache_config.cache_dtype + # block_size = cache_config.block_size + sliding_window = cache_config.sliding_window + else: + kv_cache_dtype = "auto" + # block_size = 16 + sliding_window = None + if num_kv_heads is None: + num_kv_heads = num_heads + + # The default k/v_scale is set to 1.0. This is ignored + # when kv-cache is not fp8, and should be used with + # kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we + # expect the pre-quantized k/v_scale to be loaded along + # with the model weights. + self.kv_cache_dtype = kv_cache_dtype + self._k_scale = 1.0 + self._v_scale = 1.0 + quant_method = quant_config.get_quant_method( + self, prefix=prefix) if quant_config else None + if quant_method is not None: + assert isinstance(quant_method, BaseKVCacheMethod) + # TODO (mgoin): kv cache dtype should be specified in the FP8 + # checkpoint config and become the "auto" behavior + if self.kv_cache_dtype == "fp8_e5m2": + raise ValueError("fp8_e5m2 kv-cache is not supported with " + "fp8 checkpoints.") + # If quantization is enabled, we make "k_scale" and "v_scale" + # parameters so that it can be loaded from the model checkpoint. + # The k/v_scale will then be converted back to native float32 + # values after weight loading. + self.quant_method = quant_method + self.quant_method.create_weights(self) + + impl_cls = attn_backend.get_impl_cls() + self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + blocksparse_params, logits_soft_cap) + self.attn_type = attn_backend.attn_type + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: AttentionMetadata, + kv_cache: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + return self.impl.forward(query, key, value, attn_metadata, kv_cache, + self._k_scale, self._v_scale, self.attn_type) + + def extra_repr(self) -> str: + s = f"head_size={self.impl.head_size}" # type: ignore + s += f", num_heads={self.impl.num_heads}" # type: ignore + s += f", num_kv_heads={self.impl.num_kv_heads}" # type: ignore + s += f", scale={self.impl.scale}" # type: ignore + s += f", backend={self.impl.__class__.__name__}" + return s diff --git a/vllm/wde/core/llm_engine.py b/vllm/wde/core/llm_engine.py new file mode 100644 index 0000000000000..4f72337b34d1a --- /dev/null +++ b/vllm/wde/core/llm_engine.py @@ -0,0 +1,240 @@ +from contextlib import contextmanager +from queue import Empty, Queue +from typing import TYPE_CHECKING, ClassVar, Dict, Iterable, List, Optional +from typing import Sequence as GenericSequence +from typing import Type, Union + +from vllm.logger import init_logger +from vllm.wde.core.arg_utils import EngineArgs +from vllm.wde.core.config import EngineConfig +from vllm.wde.core.schema.engine_io import (Inputs, Params, RequestOutput, + ValidationError) +from vllm.wde.core.workflow import Workflow + +logger = init_logger(__name__) +_O = RequestOutput + + +def lazy_import(module): + module_name, class_name = module.split(":") + import importlib + module = 
importlib.import_module(module_name) + return getattr(module, class_name) + + +class LLMEngine: + DO_VALIDATE_OUTPUT: ClassVar[bool] = False + """A flag to toggle whether to validate the type of request output.""" + + @classmethod + @contextmanager + def enable_output_validation(cls): + cls.DO_VALIDATE_OUTPUT = True + + yield + + cls.DO_VALIDATE_OUTPUT = False + + @classmethod + def validate_output( + cls, + output: object, + output_type: Type[_O], + ) -> _O: + do_validate = cls.DO_VALIDATE_OUTPUT + + if ((TYPE_CHECKING or do_validate) + and not isinstance(output, output_type)): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + return output + + @classmethod + def validate_outputs( + cls, + outputs: GenericSequence[object], + output_type: Type[_O], + ) -> List[_O]: + do_validate = cls.DO_VALIDATE_OUTPUT + + outputs_: List[_O] + if TYPE_CHECKING or do_validate: + outputs_ = [] + for output in outputs: + if not isinstance(output, output_type): + raise TypeError(f"Expected output of type {output_type}, " + f"but found type {type(output)}") + + outputs_.append(output) + else: + outputs_ = outputs + + return outputs_ + + def __init__(self, engine_config: EngineConfig, + workflow_cls: Type[Workflow]) -> None: + self.engine_config = engine_config + self.engine_config.log_config() + self.workflow = workflow_cls.from_engine(self) + + self._maybe_init_async_scheduling() + + self.attn_backend = lazy_import( + self.workflow.AttnBackend).from_engine(self) + self.executor = lazy_import(self.workflow.Executor).from_engine(self) + self.tokenizer = lazy_import(self.workflow.Tokenizer).from_engine(self) + self.model_inputs_builder = lazy_import( + self.workflow.ModelInputBuilder).from_engine(self) + + if hasattr(self.executor, "initialize_kv_caches"): + self.executor.initialize_kv_caches(self) + + self.input_processor = lazy_import( + self.workflow.InputProcessor).from_engine(self) + self.request_processor = lazy_import( + self.workflow.RequestProcessor).from_engine(self) + self.scheduler = lazy_import(self.workflow.Scheduler).from_engine(self) + self.output_processor = lazy_import( + self.workflow.OutputProcessor).from_engine(self) + + def _maybe_init_async_scheduling(self): + executor_cls = lazy_import(self.workflow.Executor) + scheduler_cls = lazy_import(self.workflow.Scheduler) + + if ("async_scheduling" in executor_cls.support_scheduling + and "async_scheduling" in scheduler_cls.support_scheduling): + logger.info("Use async scheduling") + self.use_async_scheduling = True + + elif ("sync_scheduling" in executor_cls.support_scheduling + and "sync_scheduling" in scheduler_cls.support_scheduling): + logger.info("Use sync scheduling") + self.use_async_scheduling = False + + else: + raise RuntimeError(f"Executor support scheduling: " + f"{executor_cls.support_scheduling}." + f"Scheduler support scheduling: " + f"{executor_cls.support_scheduling}." 
+ f"Not compatible") + + if self.use_async_scheduling: + self.executor_in = Queue() + self.executor_out = Queue() + self.max_num_on_the_fly = ( + self.engine_config.scheduler_config.max_num_on_the_fly) + self.num_on_the_fly = 0 + self.step = self.async_step + else: + self.step = self.sync_step + + @classmethod + def from_engine_args(cls, engine_args: Union[Dict, + EngineArgs]) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + + from vllm.transformers_utils.config import get_config + from vllm.wde.core.loader.utils import get_model_workflow + + if isinstance(engine_args, EngineArgs): + engine_args = engine_args.to_dict() + + hf_config = get_config(engine_args["model"], + engine_args.get("trust_remote_code", False), + engine_args.get("revision", None), + engine_args.get("code_revision", None)) + + workflow_class = get_model_workflow(hf_config) + workflow = lazy_import(workflow_class) + + engine_args = lazy_import(workflow.EngineArgs)(**engine_args) + + engine_config = engine_args.create_engine_config() + engine = cls(engine_config, workflow) + return engine + + def add_request(self, + request_id: str, + inputs: Optional[Union[str, Inputs]] = None, + params: Optional[Params] = None, + arrival_time: Optional[float] = None) -> None: + try: + request = self.input_processor(request_id, inputs, params, + arrival_time) + except ValidationError: + logger.error("%s validation error", request_id) + return + self.scheduler.add_request(request) + + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + self.scheduler.abort_request(request_id) + + def sync_step(self) -> List[RequestOutput]: + scheduler_output = self.scheduler.schedule() + if scheduler_output.is_empty(): + return [] + + executor_input = self.model_inputs_builder(scheduler_output) + executor_output = self.executor.execute_model(executor_input) + request_outputs = self.output_processor(scheduler_output, + executor_output) + self.scheduler.free_finished_request(request_outputs) + request_outputs = self.scheduler.remove_abort_request(request_outputs) + return request_outputs + + def async_step(self) -> List[RequestOutput]: + self.executor.ensure_start_execute_loop() + self._put_as_many_as_possible() + + if self.num_on_the_fly == 0: + return [] + + return self._get(block=True) + + def _put_as_many_as_possible(self): + while self.num_on_the_fly < self.max_num_on_the_fly: + scheduler_output = self.scheduler.schedule() + if scheduler_output.is_empty(): + break + executor_input = self.model_inputs_builder(scheduler_output) + + self.executor_in.put((scheduler_output, executor_input)) + self.num_on_the_fly += 1 + + def _get(self, block): + try: + scheduler_output, executor_output = self.executor_out.get(block) + except Empty: + return + + self.num_on_the_fly -= 1 + + # Theoretically, this put is not needed + # practically, task can be inqueue before doing post-processing + self._put_as_many_as_possible() + + request_outputs = self.output_processor(scheduler_output, + executor_output) + self.scheduler.free_finished_request(request_outputs) + request_outputs = self.scheduler.remove_abort_request(request_outputs) + return request_outputs + + def get_num_unfinished_requests(self) -> int: + """Gets the number of unfinished requests.""" + return self.scheduler.get_num_unfinished_requests() + + def has_unfinished_requests(self) -> bool: + """Returns True if there are unfinished requests.""" + return self.scheduler.has_unfinished_requests() + + def __reduce__(self): + # This is to ensure that the LLMEngine is not 
referenced in + # the closure used to initialize Ray worker actors + raise RuntimeError("LLMEngine should not be pickled!") + + def __del__(self): + # Shutdown model executor when engine is garbage collected + # Use getattr since __init__ can fail before the field is set + if executor := getattr(self, "executor", None): + executor.shutdown_execute_loop() diff --git a/vllm/wde/core/loader/__init__.py b/vllm/wde/core/loader/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/core/loader/loader.py b/vllm/wde/core/loader/loader.py new file mode 100644 index 0000000000000..ed26f73b5a4fb --- /dev/null +++ b/vllm/wde/core/loader/loader.py @@ -0,0 +1,631 @@ +# ruff: noqa: SIM117 +import fnmatch +import glob +import json +import math +import os +from abc import ABC, abstractmethod +from contextlib import contextmanager +from typing import Any, Dict, Generator, List, Optional, Tuple + +import huggingface_hub +import numpy as np +import torch +from huggingface_hub import HfApi, hf_hub_download +from torch import nn +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME + +from vllm.config import CacheConfig, ModelConfig, SchedulerConfig +from vllm.envs import VLLM_USE_MODELSCOPE +from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.model_loader.weight_utils import ( + download_safetensors_index_file_from_hf, download_weights_from_hf, + filter_duplicate_safetensors_files, filter_files_not_needed_for_inference, + get_quant_config, np_cache_weights_iterator, pt_weights_iterator, + safetensors_weights_iterator) +from vllm.model_executor.utils import set_weight_attrs +from vllm.utils import is_pin_memory_available +from vllm.wde.core.config import DeviceConfig, LoadConfig, LoadFormat +from vllm.wde.core.layers.attention.abstract import AttentionBackend +from vllm.wde.core.loader.utils import (get_model_architecture, + set_default_torch_dtype) + + +@contextmanager +def device_loading_context(module: torch.nn.Module, + target_device: torch.device): + if target_device.type == "cpu": + # If target is CPU, no need to move anything + yield module + return + + original_device_states: Dict[str, torch.device] = {} + + # Store original device states and move parameters to GPU if they're on CPU + for name, p in module.named_parameters(): + if p.device.type == "cpu": + original_device_states[name] = p.device + p.data = p.data.to(target_device) + # Parameters already on target device are not touched + + try: + yield module + + finally: + # Restore parameters to their original devices, ignoring new parameters + pin_memory = is_pin_memory_available() + for name, p in module.named_parameters(): + if name in original_device_states: + original_device: torch.device = original_device_states[name] + if original_device.type == "cpu": + # `torch.empty_like` does not support `pin_memory` argument + cpu_data = torch.empty_strided(size=p.data.size(), + stride=p.data.stride(), + dtype=p.data.dtype, + layout=p.data.layout, + device="cpu", + pin_memory=pin_memory) + cpu_data.copy_(p.data) + p.data = cpu_data + else: + p.data = p.data.to(original_device) + # New parameters or parameters already on target device are untouched + + +logger = init_logger(__name__) + + +def _get_quantization_config( + model_config: ModelConfig, + load_config: LoadConfig) -> Optional[QuantizationConfig]: + """Get the quantization config.""" + if model_config.quantization is not None: + quant_config = 
get_quant_config(model_config, load_config) + capability = torch.cuda.get_device_capability() + capability = capability[0] * 10 + capability[1] + if capability < quant_config.get_min_capability(): + raise ValueError( + f"The quantization method {model_config.quantization} is not " + "supported for the current GPU. " + f"Minimum capability: {quant_config.get_min_capability()}. " + f"Current capability: {capability}.") + supported_dtypes = quant_config.get_supported_act_dtypes() + if model_config.dtype not in supported_dtypes: + raise ValueError( + f"{model_config.dtype} is not supported for quantization " + f"method {model_config.quantization}. Supported dtypes: " + f"{supported_dtypes}") + return quant_config + return None + + +def initialize_model( + model_config: ModelConfig, + load_config: LoadConfig, + device_config: DeviceConfig, + attn_backend: AttentionBackend, + cache_config: Optional[CacheConfig] = None, +) -> nn.Module: + """Initialize a model with the given configurations.""" + + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + with target_device: + model_class = get_model_architecture(model_config)[0] + quant_config = _get_quantization_config(model_config, load_config) + + return model_class(config=model_config.hf_config, + cache_config=cache_config, + quant_config=quant_config, + attn_backend=attn_backend) + + +class BaseModelLoader(ABC): + """Base class for model loaders.""" + + def __init__(self, load_config: LoadConfig): + self.load_config = load_config + + @abstractmethod + def load_model(self, + model: nn.Module, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + scheduler_config: Optional[SchedulerConfig] = None, + cache_config: Optional[CacheConfig] = None) -> nn.Module: + """Load a model with the given configurations.""" + ... + + +class DefaultModelLoader(BaseModelLoader): + """Model loader that can load different file types from disk.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def _maybe_download_from_modelscope( + self, model: str, revision: Optional[str]) -> Optional[str]: + """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True. + + Returns the path to the downloaded model, or None if the model is not + downloaded from ModelScope.""" + if VLLM_USE_MODELSCOPE: + # download model from ModelScope hub, + # lazy import so that modelscope is not required for normal use. + # pylint: disable=C. + from modelscope.hub.snapshot_download import snapshot_download + + if not os.path.exists(model): + model_path = snapshot_download( + model_id=model, + cache_dir=self.load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + revision=revision, + ignore_file_pattern=self.load_config.ignore_patterns, + ) + else: + model_path = model + return model_path + return None + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str], + fall_back_to_pt: bool) -> Tuple[str, List[str], bool]: + """Prepare weights for the model. 
+ + If the model is not local, it will be downloaded.""" + model_name_or_path = self._maybe_download_from_modelscope( + model_name_or_path, revision) or model_name_or_path + + is_local = os.path.isdir(model_name_or_path) + load_format = self.load_config.load_format + use_safetensors = False + index_file = SAFE_WEIGHTS_INDEX_NAME + # Some quantized models use .pt files for storing the weights. + if load_format == LoadFormat.AUTO: + allow_patterns = ["*.safetensors", "*.bin"] + elif load_format == LoadFormat.SAFETENSORS: + use_safetensors = True + allow_patterns = ["*.safetensors"] + elif load_format == LoadFormat.MISTRAL: + use_safetensors = True + allow_patterns = ["consolidated*.safetensors"] + index_file = "consolidated.safetensors.index.json" + elif load_format == LoadFormat.PT: + allow_patterns = ["*.pt"] + elif load_format == LoadFormat.NPCACHE: + allow_patterns = ["*.bin"] + else: + raise ValueError(f"Unknown load_format: {load_format}") + + if fall_back_to_pt: + allow_patterns += ["*.pt"] + + if not is_local: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + allow_patterns, + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + else: + hf_folder = model_name_or_path + + hf_weights_files: List[str] = [] + for pattern in allow_patterns: + hf_weights_files += glob.glob(os.path.join(hf_folder, pattern)) + if len(hf_weights_files) > 0: + if pattern == "*.safetensors": + use_safetensors = True + break + + if use_safetensors: + # For models like Mistral-7B-Instruct-v0.3 + # there are both sharded safetensors files and a consolidated + # safetensors file. Using both breaks. + # Here, we download the `model.safetensors.index.json` and filter + # any files not found in the index. + if not is_local: + download_safetensors_index_file_from_hf( + model_name_or_path, index_file, + self.load_config.download_dir, revision) + hf_weights_files = filter_duplicate_safetensors_files( + hf_weights_files, hf_folder, index_file) + else: + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_folder, hf_weights_files, use_safetensors + + def _get_weights_iterator( + self, model_name_or_path: str, revision: Optional[str], + fall_back_to_pt: bool + ) -> Generator[Tuple[str, torch.Tensor], None, None]: + """Get an iterator for the model weights based on the load format.""" + hf_folder, hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision, fall_back_to_pt) + if self.load_config.load_format == LoadFormat.NPCACHE: + # Currently np_cache only support *.bin checkpoints + assert use_safetensors is False + weights_iterator = np_cache_weights_iterator( + model_name_or_path, self.load_config.download_dir, hf_folder, + hf_weights_files) + elif use_safetensors: + weights_iterator = safetensors_weights_iterator(hf_weights_files) + else: + weights_iterator = pt_weights_iterator(hf_weights_files) + + return weights_iterator + + def load_model(self, + model: nn.Module, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + scheduler_config: Optional[SchedulerConfig] = None, + cache_config: Optional[CacheConfig] = None) -> nn.Module: + target_device = torch.device(device_config.device) + with set_default_torch_dtype(model_config.dtype): + model.load_weights( + self._get_weights_iterator(model_config.model, + model_config.revision, + fall_back_to_pt=getattr( + 
model, + "fall_back_to_pt_during_load", + True)), ) + + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + # When quant methods need to process weights after loading + # (for repacking, quantizing, etc), they expect parameters + # to be on the global target device. This scope is for the + # case where cpu offloading is used, where we will move the + # parameters onto device for processing and back off after. + with device_loading_context(module, target_device): + quant_method.process_weights_after_loading(module) + return model.eval() + + +class DummyModelLoader(BaseModelLoader): + """Model loader that will set model weights to random values.""" + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + if load_config.model_loader_extra_config: + raise ValueError(f"Model loader extra config is not supported for " + f"load format {load_config.load_format}") + + def load_model(self, + model: nn.Module, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + scheduler_config: Optional[SchedulerConfig] = None, + cache_config: Optional[CacheConfig] = None) -> nn.Module: + return model.eval() + + +class BitsAndBytesModelLoader(BaseModelLoader): + """Model loader to load model weights with BitAndBytes quantization.""" + + default_target_modules = [ + "gate_proj", "down_proj", "up_proj", "q_proj", "k_proj", "v_proj", + "o_proj" + ] + + possible_config_file_names = ["adapter_config.json"] + + def __init__(self, load_config: LoadConfig): + super().__init__(load_config) + + # we don't need to quantize the whole model, only the target modules + # that are specified in the adapter config file. If the adapter config + # file is not provided, we will quantize the default modules. + if (not load_config.model_loader_extra_config + or "qlora_adapter_name_or_path" + not in load_config.model_loader_extra_config): + self.target_modules = self.default_target_modules + return + + qlora_adapter = load_config.model_loader_extra_config[ + "qlora_adapter_name_or_path"] + + config_file_path = self._get_config_file(qlora_adapter) + + with open(config_file_path, "r") as f: + config = json.load(f) + self.target_modules = config["target_modules"] + + def _get_config_file(self, qlora_adapter: str) -> str: + is_local = os.path.isdir(qlora_adapter) + config_file_path = None + if is_local: + for file in self.possible_config_file_names: + config_file_path = os.path.join(qlora_adapter, file) + if os.path.exists(config_file_path): + break + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=qlora_adapter) + for file in self.possible_config_file_names: + if file in repo_files: + config_file_path = hf_hub_download(repo_id=qlora_adapter, + filename=file) + break + + if not config_file_path: + raise ValueError( + f"Cannot find adapter config file in {qlora_adapter}") + + return config_file_path + + def _get_weight_files( + self, + model_name_or_path: str, + allowed_patterns: List[str], + revision: Optional[str] = None) -> Tuple[List[str], str]: + """Retrieve weight files. Download the files if necessary. 
+ + Return the weight files and the file pattern.""" + is_local = os.path.isdir(model_name_or_path) + + if is_local: + for pattern in allowed_patterns: + weight_files = glob.glob( + os.path.join(model_name_or_path, pattern)) + if weight_files: + return weight_files, pattern + else: + hf_api = HfApi() + repo_files = hf_api.list_repo_files(repo_id=model_name_or_path) + for pattern in allowed_patterns: + matching_files = fnmatch.filter(repo_files, pattern) + if matching_files: + hf_folder = download_weights_from_hf( + model_name_or_path, + self.load_config.download_dir, + [pattern], + revision, + ignore_patterns=self.load_config.ignore_patterns, + ) + return glob.glob(os.path.join(hf_folder, pattern)), pattern + + raise RuntimeError( + f"No model weights found in: `{model_name_or_path}`") + + def _prepare_weights(self, model_name_or_path: str, + revision: Optional[str]) -> Tuple[List[str], bool]: + """Prepare weight files for the model.""" + + allowed_patterns = ["*.safetensors", "*.bin", "*.pt"] + + hf_weights_files, matched_pattern = self._get_weight_files( + model_name_or_path, allowed_patterns, revision) + + if matched_pattern != "*.safetensors": + hf_weights_files = filter_files_not_needed_for_inference( + hf_weights_files) + + if len(hf_weights_files) == 0: + raise RuntimeError( + f"Cannot find any model weights with `{model_name_or_path}`") + + return hf_weights_files, matched_pattern == "*.safetensors" + + def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool): + if use_safetensors: + return safetensors_weights_iterator(hf_weights_files) + else: + return pt_weights_iterator(hf_weights_files) + + def _get_quantized_weights_iterator( + self, model_name_or_path: str, revision: Optional[str], pre_quant: bool + ) -> Tuple[Generator[Tuple[str, torch.Tensor], None, None], Dict[str, + Any]]: + """Get an iterator to the model weights with bitsandbytes quantization, + as well as the quantization state dictionary.""" + + # only load the bitsandbytes module when needed + try: + import bitsandbytes + from bitsandbytes.functional import QuantState + if bitsandbytes.__version__ < "0.42.0": + raise ImportError("bitsandbytes version is wrong. Please " + "install bitsandbytes>=0.42.0.") + from bitsandbytes.functional import quantize_4bit + except ImportError as err: + raise ImportError("Please install bitsandbytes>=0.42.0 via " + "`pip install bitsandbytes>=0.42.0` to use " + "bitsandbytes quantizer.") from err + + hf_weights_files, use_safetensors = self._prepare_weights( + model_name_or_path, revision) + + quant_state_dict = {} + + def quantized_checkpoint() -> Generator: + # First iterate over all quant state weights + weight_iterator = self._hf_weight_iter(hf_weights_files, + use_safetensors) + temp_state_dict = {} + for weight_name, weight_tensor in weight_iterator: + if weight_name.endswith(".weight"): + continue + # TODO: only nf4 quantization is supported for now + if weight_name.endswith(".quant_state.bitsandbytes__fp4"): + raise NotImplementedError( + "Only bitsandbytes_nf4 quantization" + f"is supported for now. {weight_name} is fp4 quantized" + ) + temp_state_dict[weight_name] = weight_tensor + + # Closure to parse quant_state for each prequant weight + def _parse_quant_state(param_name: str, + temp_state_dict: Dict) -> QuantState: + quant_state = {} + for k in temp_state_dict: + if param_name + "." 
in k: + quant_state[k] = temp_state_dict[k] + # bitsandbytes library requires + # weight.quant_state.bitsandbytes__nf4 in CPU + quant_state[param_name + + ".quant_state.bitsandbytes__nf4"] = quant_state[ + param_name + + ".quant_state.bitsandbytes__nf4"].cpu().data + return QuantState.from_dict(quant_state, device="cuda") + + # Second iterate over all prequant and normal weights + # pre quantized weights would have a quant_state + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + # Filter out all weights whose suffix is not ".weight" + if not weight_name.endswith(".weight"): + continue + if weight_name + ".quant_state.bitsandbytes__nf4" \ + in temp_state_dict: + quant_state = _parse_quant_state(weight_name, + temp_state_dict) + weight_name = weight_name.replace(".weight", ".qweight") + quant_state_dict[weight_name] = quant_state + yield weight_name.replace(".weight", + ".qweight"), weight_tensor + else: + yield weight_name, weight_tensor + + def generator() -> Generator: + for weight_name, weight_tensor in self._hf_weight_iter( + hf_weights_files, use_safetensors): + if any(target_module in weight_name + for target_module in self.target_modules): + weight_name = weight_name.replace(".weight", ".qweight") + # bitsandbytes requires data in GPU + loaded_weight = weight_tensor.cuda().data + with set_default_torch_dtype(torch.float32): + processed_weight, quant_state = quantize_4bit( + loaded_weight, + compress_statistics=True, + quant_type="nf4") + + quant_state_dict[weight_name] = quant_state + else: + processed_weight = weight_tensor + + yield weight_name, processed_weight + + if pre_quant: + return quantized_checkpoint(), quant_state_dict + return generator(), quant_state_dict + + def _load_weights(self, model_config: ModelConfig, + model: nn.Module) -> None: + if not hasattr(model, 'load_weights'): + raise AttributeError( + "The required method 'load_weights' is not defined in class" + f" {type(self).__name__}.") + + if not hasattr(model, 'bitsandbytes_stacked_params_mapping'): + raise AttributeError( + f"Model {type(self).__name__} does not support BitsAndBytes " + "quantization yet.") + + logger.info("Loading weights with BitsAndBytes quantization. 
" + " May take a while ...") + + is_quantized_checkpoint = False + quant_config = getattr(model_config.hf_config, "quantization_config", + None) + if quant_config is not None and quant_config.get( + 'quant_method') == "bitsandbytes": + is_quantized_checkpoint = True + + qweight_iterator, quant_state_dict = \ + self._get_quantized_weights_iterator( + model_config.model, model_config.revision, is_quantized_checkpoint) + + model.load_weights(qweight_iterator) + + torch.cuda.empty_cache() + + param_dict = dict(model.named_parameters()) + stacked_quant_state_dict: Dict[str, Dict[int, Any]] = {} + for quant_param_name in quant_state_dict: + non_stacked_param_name = quant_param_name + + shard_index = 0 + for shard_name, ( + weight_name, index + ) in model.bitsandbytes_stacked_params_mapping.items(): + if shard_name in quant_param_name: + shard_index = index + quant_param_name = quant_param_name.replace( + shard_name, weight_name) + break + + if quant_param_name not in param_dict: + raise ValueError( + f"Parameter {quant_param_name} not found in the model.") + + if quant_param_name not in stacked_quant_state_dict: + stacked_quant_state_dict[quant_param_name] = {} + + stacked_quant_state_dict[quant_param_name][shard_index] = ( + quant_state_dict[non_stacked_param_name]) + + # save quant_states and offsets as the attributes of the parameters + for param_name, param in param_dict.items(): + if param_name in stacked_quant_state_dict: + quant_states = stacked_quant_state_dict[param_name] + set_weight_attrs(param, {"bnb_quant_state": quant_states}) + + pack_ratio = getattr(param, "pack_factor", -1) + if pack_ratio == -1: + raise ValueError( + f"pack_factor not set for parameter {param_name}.") + + num_elements = [0] * len(quant_states) + for seq, quant_state in quant_states.items(): + num_elements[seq] = math.prod( + quant_state.shape) // pack_ratio + + offsets = np.concatenate(([0], np.cumsum(num_elements))) + set_weight_attrs(param, {"bnb_shard_offsets": offsets}) + + def load_model(self, + model: nn.Module, + *, + model_config: ModelConfig, + device_config: DeviceConfig, + scheduler_config: Optional[SchedulerConfig] = None, + cache_config: Optional[CacheConfig] = None) -> nn.Module: + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + self._load_weights(model_config, model) + + return model.eval() + + +def get_model_loader(load_config: LoadConfig) -> BaseModelLoader: + """Get a model loader based on the load format.""" + + if isinstance(load_config.load_format, type): + return load_config.load_format(load_config) + + if load_config.load_format == LoadFormat.DUMMY: + return DummyModelLoader(load_config) + + if load_config.load_format == LoadFormat.BITSANDBYTES: + return BitsAndBytesModelLoader(load_config) + + return DefaultModelLoader(load_config) diff --git a/vllm/wde/core/loader/utils.py b/vllm/wde/core/loader/utils.py new file mode 100644 index 0000000000000..c03c1acaf7dbd --- /dev/null +++ b/vllm/wde/core/loader/utils.py @@ -0,0 +1,48 @@ +"""Utilities for selecting and loading models.""" +import contextlib +from typing import Tuple, Type + +import torch +from torch import nn +from transformers import PretrainedConfig + +from vllm.wde.core.config import ModelConfig +from vllm.wde.core.modelzoo import ModelRegistry + + +@contextlib.contextmanager +def set_default_torch_dtype(dtype: torch.dtype): + """Sets the default torch dtype to the given dtype.""" + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(dtype) + yield + 
torch.set_default_dtype(old_dtype) + + +def get_model_architecture( + model_config: ModelConfig) -> Tuple[Type[nn.Module], str]: + architectures = getattr(model_config.hf_config, "architectures", []) + + for arch in architectures: + model_cls = ModelRegistry.load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def get_model_workflow(hf_config: PretrainedConfig) -> str: + architectures = getattr(hf_config, "architectures", []) + + for arch in architectures: + workflow = ModelRegistry.get_workflow(arch) + if workflow is not None: + return workflow + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + +def get_architecture_class_name(model_config: ModelConfig) -> str: + return get_model_architecture(model_config)[1] diff --git a/vllm/wde/core/modelzoo.py b/vllm/wde/core/modelzoo.py new file mode 100644 index 0000000000000..66c9923019356 --- /dev/null +++ b/vllm/wde/core/modelzoo.py @@ -0,0 +1,69 @@ +import functools +import importlib +from typing import Dict, List, Optional, Type + +import torch.nn as nn + +from vllm.logger import init_logger +from vllm.wde.decode_only.modelzoo import DECODE_ONLY_MODELS +from vllm.wde.encode_only.modelzoo import ENCODE_ONLY_MODELS +from vllm.wde.reranker.modelzoo import RERANKER_MODELS +from vllm.wde.retriever.modelzoo import RETRIEVER_MODELS + +logger = init_logger(__name__) + +_MODELS_LIST = [ + ENCODE_ONLY_MODELS, RETRIEVER_MODELS, RERANKER_MODELS, DECODE_ONLY_MODELS +] + +_MODELS = dict() +for m in _MODELS_LIST: + _MODELS.update(**m) + +# Architecture -> type. 
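+# _MODELS maps Architecture -> (task, module, class, workflow); for example,
+# DECODE_ONLY_MODELS (registered later in this PR) contains
+#   "LlamaForCausalLM": ("decode_only", "llama", "LlamaForCausalLM",
+#                        "vllm.wde.decode_only.workflow:DecodeOnlyWorkflow"),
+# which ModelRegistry resolves by importing
+# vllm.wde.decode_only.modelzoo.llama and returning its LlamaForCausalLM.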
+# out of tree models +_OOT_MODELS: Dict[str, Type[nn.Module]] = {} + + +class ModelRegistry: + + @staticmethod + @functools.lru_cache(maxsize=128) + def _get_model(model_arch: str): + task, module_name, model_cls_name, workflow = _MODELS[model_arch] + module = importlib.import_module( + f"vllm.wde.{task}.modelzoo.{module_name}") + return getattr(module, model_cls_name, None) + + @staticmethod + def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch in _OOT_MODELS: + return _OOT_MODELS[model_arch] + if model_arch not in _MODELS: + return None + return ModelRegistry._get_model(model_arch) + + @staticmethod + def get_supported_archs() -> List[str]: + return list(_MODELS.keys()) + + @staticmethod + @functools.lru_cache(maxsize=128) + def get_workflow(model_arch: str): + task, module_name, model_cls_name, workflow = _MODELS[model_arch] + return workflow + + @staticmethod + def register_model(model_arch: str, model_cls: Type[nn.Module]): + if model_arch in _MODELS: + logger.warning( + "Model architecture %s is already registered, and will be " + "overwritten by the new model class %s.", model_arch, + model_cls.__name__) + global _OOT_MODELS + _OOT_MODELS[model_arch] = model_cls + + +__all__ = [ + "ModelRegistry", +] diff --git a/vllm/wde/core/processor/__init__.py b/vllm/wde/core/processor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/core/processor/input_processor.py b/vllm/wde/core/processor/input_processor.py new file mode 100644 index 0000000000000..66621dde67c4a --- /dev/null +++ b/vllm/wde/core/processor/input_processor.py @@ -0,0 +1,40 @@ +from abc import ABC, abstractmethod +from typing import Optional, Union + +from vllm.wde.core.llm_engine import LLMEngine +from vllm.wde.core.schema.engine_io import (Inputs, Params, Request, + SchedulableRequest) + + +class InputProcessor(ABC): + """ + Input(request_id, inputs, params, arrival_time) -> InputProcessor -> Request + """ + + @abstractmethod + def __call__(self, + request_id: str, + inputs: Optional[Union[str, Inputs]] = None, + params: Optional[Params] = None, + arrival_time: Optional[float] = None) -> Request: + raise NotImplementedError + + @classmethod + @abstractmethod + def from_engine(cls, engine: LLMEngine): + raise NotImplementedError + + +class RequestProcessor(ABC): + """ + Request -> RequestProcessor -> SchedulableRequest + """ + + @abstractmethod + def __call__(self, request: Request) -> SchedulableRequest: + raise NotImplementedError + + @classmethod + @abstractmethod + def from_engine(cls, engine: LLMEngine): + raise NotImplementedError diff --git a/vllm/wde/core/processor/model_input_builder.py b/vllm/wde/core/processor/model_input_builder.py new file mode 100644 index 0000000000000..f5ccb33f62d9e --- /dev/null +++ b/vllm/wde/core/processor/model_input_builder.py @@ -0,0 +1,21 @@ +from abc import ABC, abstractmethod + +from vllm.wde.core.llm_engine import LLMEngine +from vllm.wde.core.schema.engine_io import SchedulerOutput +from vllm.wde.core.schema.execute_io import ExecuteInput + + +class ModelInputBuilder(ABC): + """ + scheduler_output = scheduler.schedule() + SchedulerOutput -> ModelInputBuilder -> ExecuteInput + """ + + @abstractmethod + def __call__(self, scheduler_output: SchedulerOutput) -> ExecuteInput: + raise NotImplementedError + + @classmethod + @abstractmethod + def from_engine(cls, engine: LLMEngine): + raise NotImplementedError diff --git a/vllm/wde/core/processor/output_processor.py b/vllm/wde/core/processor/output_processor.py new file 
mode 100644 index 0000000000000..05ab3ab9b5ea0 --- /dev/null +++ b/vllm/wde/core/processor/output_processor.py @@ -0,0 +1,23 @@ +from abc import ABC, abstractmethod +from typing import List + +import torch + +from vllm.wde.core.llm_engine import LLMEngine +from vllm.wde.core.schema.engine_io import RequestOutput, SchedulerOutput + + +class OutputProcessor(ABC): + """ + scheduler_output, execute_output -> OutputProcessor -> RequestOutput + """ + + @abstractmethod + def __call__(self, scheduler_output: SchedulerOutput, + execute_output: torch.Tensor) -> List[RequestOutput]: + raise NotImplementedError + + @classmethod + @abstractmethod + def from_engine(cls, engine: LLMEngine): + raise NotImplementedError diff --git a/vllm/wde/core/scheduler.py b/vllm/wde/core/scheduler.py new file mode 100644 index 0000000000000..1551baf5b157d --- /dev/null +++ b/vllm/wde/core/scheduler.py @@ -0,0 +1,85 @@ +from abc import ABC, abstractmethod +from collections import deque +from typing import Deque, Iterable, List, Union + +from vllm.logger import init_logger +from vllm.wde.core.config import SchedulerConfig +from vllm.wde.core.llm_engine import LLMEngine +from vllm.wde.core.processor.input_processor import RequestProcessor +from vllm.wde.core.schema.engine_io import (Request, RequestOutput, + SchedulerOutput) + +logger = init_logger(__name__) + + +class Scheduler(ABC): + support_scheduling = [] + + def __init__( + self, + scheduler_config: SchedulerConfig, + request_processor: RequestProcessor, + ) -> None: + self.scheduler_config = scheduler_config + self.request_processor = request_processor + + self.waiting: Deque[Request] = deque() + + self.requests = set() + self.aborted_requests = set() + + @classmethod + def from_engine(cls, engine: LLMEngine) -> "Scheduler": + raise NotImplementedError + + def add_request(self, request: Request) -> None: + if (request.request_id in self.requests + or request.request_id in self.aborted_requests): + logger.warning("[%s] request_id conflict") + return + + self.waiting.append(request) + self.requests.add(request.request_id) + + def abort_request(self, request_id: Union[str, Iterable[str]]) -> None: + if isinstance(request_id, str): + request_id = (request_id, ) + request_ids = set(request_id) + + self.requests -= request_ids + self.aborted_requests += request_ids + + def remove_abort_request( + self, request_outputs: List[RequestOutput]) -> List[RequestOutput]: + if len(self.aborted_requests) == 0: + return request_outputs + + current_ids = set(request.request_id for request in request_outputs) + need_abort = self.aborted_requests & current_ids + + if len(need_abort) == 0: + return request_outputs + + request_outputs = [ + request for request in request_outputs + if request.request_id not in need_abort + ] + self.aborted_requests -= need_abort + + return request_outputs + + def has_unfinished_requests(self) -> bool: + return len(self.requests) != 0 + + def get_num_unfinished_requests(self) -> int: + return len(self.requests) + + @abstractmethod + def schedule(self) -> SchedulerOutput: + raise NotImplementedError + + def free_finished_request(self, request_outputs: List[RequestOutput]): + finished_request_ids = set(request.request_id + for request in request_outputs + if request.finished) + self.requests -= finished_request_ids diff --git a/vllm/wde/core/schema/__init__.py b/vllm/wde/core/schema/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/core/schema/engine_io.py b/vllm/wde/core/schema/engine_io.py new file mode 100644 index 
0000000000000..e867f91a31861 --- /dev/null +++ b/vllm/wde/core/schema/engine_io.py @@ -0,0 +1,63 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + + +class Params: + pass + + +class Inputs: + pass + + +@dataclass +class Request: + request_id: str + arrival_time: float + + +@dataclass +class TextPrompt(Inputs): + """Schema for a text prompt.""" + + prompt: str + """The input text to be tokenized before passing to the model.""" + + +@dataclass +class TokensPrompt(Inputs): + """Schema for a tokenized prompt.""" + + prompt_token_ids: List[int] + """A list of token IDs to pass to the model.""" + + +PromptInput = Union[str, TextPrompt, TokensPrompt] + + +@dataclass +class TextOnlyInputs(Inputs): + prompt_token_ids: List[int] + """The token IDs of the prompt.""" + + prompt: Optional[str] + """ + The original prompt text corresponding to the token IDs, if available. + """ + + +class SchedulableRequest(Request): + pass + + +@dataclass +class SchedulerOutput: + pass + + +class RequestOutput(Request): + finished: bool + + +class ValidationError(ValueError): + pass diff --git a/vllm/wde/core/schema/execute_io.py b/vllm/wde/core/schema/execute_io.py new file mode 100644 index 0000000000000..67d91a8b078dd --- /dev/null +++ b/vllm/wde/core/schema/execute_io.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass +from typing import Optional + + +@dataclass +class ModelInput: + pass + + +@dataclass +class WorkerInput: + pass + + +@dataclass +class ExecuteInput: + worker_input: Optional[WorkerInput] + model_input: Optional[ModelInput] + + +class ExecuteOutput: + pass diff --git a/vllm/wde/core/worker.py b/vllm/wde/core/worker.py new file mode 100644 index 0000000000000..7549397219e50 --- /dev/null +++ b/vllm/wde/core/worker.py @@ -0,0 +1,104 @@ +import importlib +import os +from abc import ABC, abstractmethod +from typing import Callable, Dict, Optional, Type + +from vllm.logger import init_logger +from vllm.utils import (enable_trace_function_call_for_thread, + update_environment_variables) +from vllm.wde.core.schema.execute_io import ExecuteInput + +logger = init_logger(__name__) + + +class WorkerBase(ABC): + + @abstractmethod + def __call__(self, execute_input: ExecuteInput): + raise NotImplementedError + + +class WorkerWrapperBase: + """ + The whole point of this class is to lazily initialize the worker. + We first instantiate the WorkerWrapper, which remembers the worker module + and class name. Then, when we call `update_environment_variables`, and the + real initialization happens in `init_worker`. + + If worker_class_fn is specified, it will be executed to get the worker + class. + Otherwise, the worker class will be obtained by dynamically importing it + using worker_module_name and worker_class_name. 
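+
+    A minimal sketch of the intended call order (mirroring ``create_worker``
+    below; "some.module", "SomeWorker" and ``worker_kwargs`` are placeholders):
+
+        wrapper = WorkerWrapperBase(worker_module_name="some.module",
+                                    worker_class_name="SomeWorker")
+        wrapper.update_environment_variables({"CUDA_VISIBLE_DEVICES": "0"})
+        wrapper.init_worker(**worker_kwargs)
+        worker = wrapper.worker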
+ """ + + def __init__( + self, + worker_module_name: str, + worker_class_name: str, + trust_remote_code: bool = False, + worker_class_fn: Optional[Callable[[], + Type[WorkerBase]]] = None) -> None: + self.worker_module_name = worker_module_name + self.worker_class_name = worker_class_name + self.worker_class_fn = worker_class_fn + self.worker: Optional[WorkerBase] = None + if trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + @staticmethod + def update_environment_variables(envs: Dict[str, str]) -> None: + key = 'CUDA_VISIBLE_DEVICES' + if key in envs and key in os.environ: + # overwriting CUDA_VISIBLE_DEVICES is desired behavior + # suppress the warning in `update_environment_variables` + del os.environ[key] + update_environment_variables(envs) + + def init_worker(self, *args, **kwargs): + """ + Here we inject some common logic before initializing the worker. + Arguments are passed to the worker class constructor. + """ + enable_trace_function_call_for_thread() + + # see https://github.com/NVIDIA/nccl/issues/1234 + os.environ['NCCL_CUMEM_ENABLE'] = '0' + + if self.worker_class_fn: + worker_class = self.worker_class_fn() + else: + mod = importlib.import_module(self.worker_module_name) + worker_class = getattr(mod, self.worker_class_name) + + self.worker = worker_class(*args, **kwargs) + assert self.worker is not None + + def execute_method(self, method, *args, **kwargs): + try: + target = self if self.worker is None else self.worker + executor = getattr(target, method) + return executor(*args, **kwargs) + except Exception as e: + # if the driver worker also execute methods, + # exceptions in the rest worker may cause deadlock in rpc like ray + # see https://github.com/vllm-project/vllm/issues/3455 + # print the error and inform the user to solve the error + msg = (f"Error executing method {method}. 
" + "This might cause deadlock in distributed execution.") + logger.exception(msg) + raise e + + +def create_worker(module, envs=None, **kwargs): + module_name, class_name = module.split(":") + wrapper = WorkerWrapperBase( + worker_module_name=module_name, + worker_class_name=class_name, + ) + if envs: + wrapper.update_environment_variables(envs) + + wrapper.init_worker(**kwargs) + return wrapper.worker \ No newline at end of file diff --git a/vllm/wde/core/workflow.py b/vllm/wde/core/workflow.py new file mode 100644 index 0000000000000..b12d6e4ac4185 --- /dev/null +++ b/vllm/wde/core/workflow.py @@ -0,0 +1,16 @@ +class Workflow: + EngineArgs: str + Scheduler: str + AttnBackend: str + attn_type: str + Tokenizer: str = "vllm.wde.core.inputs.tokenizer:Tokenizer" + InputProcessor: str + RequestProcessor: str + OutputProcessor: str + ModelInputBuilder: str + Executor: str + Worker: str + + @classmethod + def from_engine(cls, engine): + return cls() diff --git a/vllm/wde/decode_only/__init__.py b/vllm/wde/decode_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/decode_only/arg_utils.py b/vllm/wde/decode_only/arg_utils.py new file mode 100644 index 0000000000000..14c69c2aef780 --- /dev/null +++ b/vllm/wde/decode_only/arg_utils.py @@ -0,0 +1,119 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +from vllm.logger import init_logger +from vllm.wde.core.arg_utils import EngineArgs +from vllm.wde.core.config import (DeviceConfig, LoadConfig, + filter_unexpected_fields) +from vllm.wde.decode_only.config import (DecodeOnlyEmbeddingSchedulerConfig, + DecodeOnlyEngineConfig, + DecodeOnlyModelConfig, + DecodeOnlySchedulerConfig, + PrefillOnlyParallelConfig) + +logger = init_logger(__name__) + + +def nullable_str(val: str): + if not val or val == "None": + return None + return val + + +@filter_unexpected_fields +@dataclass +class DecodeOnlyEngineArgs(EngineArgs): + """Arguments for vLLM engine.""" + model: str + served_model_name: Optional[Union[List[str]]] = None + tokenizer: Optional[str] = None + skip_tokenizer_init: bool = False + tokenizer_mode: str = 'auto' + trust_remote_code: bool = False + download_dir: Optional[str] = None + load_format: str = 'auto' + dtype: str = 'auto' + kv_cache_dtype: str = 'auto' + quantization_param_path: Optional[str] = None + disable_sliding_window: bool = False + seed: int = 0 + max_model_len: Optional[int] = None + max_num_batched_tokens: Optional[int] = None + + output_last_hidden_states: bool = False + enable_bidirectional: bool = False + + max_num_seqs: int = 256 + max_num_on_the_fly: int = 3 + scheduling: str = "async" + + data_parallel_size: int = 0 + + disable_log_stats: bool = False + revision: Optional[str] = None + code_revision: Optional[str] = None + rope_scaling: Optional[dict] = None + rope_theta: Optional[float] = None + tokenizer_revision: Optional[str] = None + quantization: Optional[str] = None + disable_custom_all_reduce: bool = False + device: str = 'auto' + model_loader_extra_config: Optional[dict] = None + ignore_patterns: Optional[Union[str, List[str]]] = None + + def __post_init__(self): + if self.tokenizer is None: + self.tokenizer = self.model + + def create_engine_config(self) -> DecodeOnlyEngineConfig: + device_config = DeviceConfig(device=self.device) + model_config = DecodeOnlyModelConfig( + model=self.model, + tokenizer=self.tokenizer, + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + dtype=self.dtype, + seed=self.seed, + 
revision=self.revision, + code_revision=self.code_revision, + rope_scaling=self.rope_scaling, + rope_theta=self.rope_theta, + tokenizer_revision=self.tokenizer_revision, + max_model_len=self.max_model_len, + quantization=self.quantization, + quantization_param_path=self.quantization_param_path, + disable_sliding_window=self.disable_sliding_window, + skip_tokenizer_init=self.skip_tokenizer_init, + served_model_name=self.served_model_name, + output_last_hidden_states=self.output_last_hidden_states, + enable_bidirectional=self.enable_bidirectional) + + if model_config.output_last_hidden_states: + scheduler_config = DecodeOnlyEmbeddingSchedulerConfig( + max_num_batched_tokens=self.max_num_batched_tokens, + max_num_seqs=self.max_num_seqs, + max_model_len=model_config.max_model_len, + max_num_on_the_fly=self.max_num_on_the_fly, + scheduling=self.scheduling) + else: + scheduler_config = DecodeOnlySchedulerConfig() + + if (model_config.output_last_hidden_states + and self.data_parallel_size > 0): + parallel_config = PrefillOnlyParallelConfig( + data_parallel_size=self.data_parallel_size) + else: + parallel_config = None + + load_config = LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, + ignore_patterns=self.ignore_patterns, + ) + + return DecodeOnlyEngineConfig(model_config=model_config, + scheduler_config=scheduler_config, + device_config=device_config, + load_config=load_config, + parallel_config=parallel_config) diff --git a/vllm/wde/decode_only/config.py b/vllm/wde/decode_only/config.py new file mode 100644 index 0000000000000..81216439d8ea7 --- /dev/null +++ b/vllm/wde/decode_only/config.py @@ -0,0 +1,89 @@ +from dataclasses import fields +from typing import Optional + +from vllm.logger import init_logger +from vllm.wde.core.config import EngineConfig, ModelConfig, SchedulerConfig +from vllm.wde.prefill_only.config import (PrefillOnlyParallelConfig, + PrefillOnlySchedulerConfig) + +logger = init_logger(__name__) + +_GB = 1 << 30 + + +class DecodeOnlyModelConfig(ModelConfig): + + def __init__(self, + output_last_hidden_states: bool = False, + enable_bidirectional: bool = None, + **kwargs) -> None: + super().__init__(**kwargs) + self.output_last_hidden_states = output_last_hidden_states + self.enable_bidirectional = enable_bidirectional + + self._verify_parameters() + + def _verify_parameters(self) -> None: + if self.enable_bidirectional is None: + if hasattr(self.hf_config, "enable_bidirectional"): + self.enable_bidirectional = self.hf_config.enable_bidirectional + else: + self.enable_bidirectional = False + + if self.enable_bidirectional: + self.output_last_hidden_states = True + + +class DecodeOnlySchedulerConfig(SchedulerConfig): + + def __init__(self): + self.output_last_hidden_states = False + + +class DecodeOnlyEmbeddingSchedulerConfig(DecodeOnlySchedulerConfig, + PrefillOnlySchedulerConfig): + + def __init__(self, + max_model_len: int, + max_num_batched_tokens: Optional[int] = None, + max_num_requests: Optional[int] = None, + max_num_seqs: Optional[int] = None, + max_num_on_the_fly: Optional[int] = 3, + scheduling: str = "sync") -> None: + PrefillOnlySchedulerConfig.__init__(self, max_model_len, + max_num_batched_tokens, + max_num_requests, max_num_seqs, + max_num_on_the_fly, scheduling) + self.output_last_hidden_states = True + + +class DecodeOnlyEngineConfig(EngineConfig): + model_config: DecodeOnlyModelConfig + scheduler_config: DecodeOnlySchedulerConfig + parallel_config: 
Optional[PrefillOnlyParallelConfig] + + def to_dict(self): + """Return the configs as a dictionary, for use in **kwargs. + """ + return dict( + (field.name, getattr(self, field.name)) for field in fields(self)) + + def log_config(self): + from vllm.version import __version__ as VLLM_VERSION + if self.scheduler_config.output_last_hidden_states: + logger.info( + "Initializing an Encode Only engine (v%s) with config: " + "model=%r, tokenizer=%r, " + "tokenizer_mode=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, " + "device_config=%s, served_model_name=%s, " + "max_num_on_the_fly=%d, scheduling=%s)", VLLM_VERSION, + self.model_config.model, self.model_config.tokenizer, + self.model_config.tokenizer_mode, + self.model_config.trust_remote_code, self.model_config.dtype, + self.model_config.max_model_len, self.load_config.download_dir, + self.load_config.load_format, self.device_config.device, + self.model_config.served_model_name, + self.scheduler_config.max_num_on_the_fly, + self.scheduler_config.scheduling) diff --git a/vllm/wde/decode_only/modelzoo/__init__.py b/vllm/wde/decode_only/modelzoo/__init__.py new file mode 100644 index 0000000000000..35490fd13912d --- /dev/null +++ b/vllm/wde/decode_only/modelzoo/__init__.py @@ -0,0 +1,10 @@ +TASK = "decode_only" +WORKFLOW = "vllm.wde.decode_only.workflow:DecodeOnlyWorkflow" + +# Architecture -> (task, module, class, workflow). +DECODE_ONLY_MODELS = { + "LlamaForCausalLM": (TASK, "llama", "LlamaForCausalLM", WORKFLOW), + "Qwen2ForCausalLM": + (TASK, "qwen2", "Qwen2ForCausalLM", + "vllm.wde.retriever.modelzoo.gte_qwen.workflow:Qwen2Workflow"), +} diff --git a/vllm/wde/decode_only/modelzoo/llama.py b/vllm/wde/decode_only/modelzoo/llama.py new file mode 100644 index 0000000000000..4d8c98c04ea1c --- /dev/null +++ b/vllm/wde/decode_only/modelzoo/llama.py @@ -0,0 +1,613 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Inference-only LLaMA model compatible with HuggingFace weights.""" +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch import nn +from transformers import LlamaConfig + +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size) +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.interfaces import SupportsLoRA +from vllm.model_executor.models.utils import (PPMissingLayer, + is_pp_missing_parameter, + make_layers) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.utils import is_hip +from vllm.wde.core.layers.attention import (Attention, AttentionBackend, + AttentionMetadata) + + +class LlamaMLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + prefix: str = "", + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj") + self.down_proj = RowParallelLinear(input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj") + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class LlamaAttention(nn.Module): + + def __init__( + self, + config: LlamaConfig, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + attn_backend: AttentionBackend, + rope_theta: float = 10000, + rope_scaling: Optional[Dict[str, Any]] = None, + max_position_embeddings: int = 8192, + quant_config: Optional[QuantizationConfig] = None, + bias: bool = False, + cache_config: Optional[CacheConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. 
+ assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + # MistralConfig has an optional head_dim introduced by Mistral-Nemo + self.head_dim = getattr(config, "head_dim", + self.hidden_size // self.total_num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + self.max_position_embeddings = max_position_embeddings + + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + + self.o_proj = RowParallelLinear( + input_size=self.total_num_heads * self.head_dim, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + + is_neox_style = True + if quant_config is not None and quant_config.get_name() == "gguf": + is_neox_style = False + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position_embeddings, + base=rope_theta, + rope_scaling=rope_scaling, + is_neox_style=is_neox_style, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + attn_backend=attn_backend) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + kv_cache: Optional[List[torch.Tensor]], + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class LlamaDecoderLayer(nn.Module): + + def __init__( + self, + config: LlamaConfig, + attn_backend: AttentionBackend, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + rope_theta = getattr(config, "rope_theta", 10000) + rope_scaling = getattr(config, "rope_scaling", None) + if rope_scaling is not None and getattr( + config, "original_max_position_embeddings", None): + rope_scaling["original_max_position_embeddings"] = ( + config.original_max_position_embeddings) + max_position_embeddings = getattr(config, "max_position_embeddings", + 8192) + # Support abacusai/Smaug-72B-v0.1 with attention_bias + # Support internlm/internlm-7b with bias + attention_bias = getattr(config, "attention_bias", False) or getattr( + config, "bias", False) + self.self_attn = LlamaAttention( + config=config, + attn_backend=attn_backend, + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + num_kv_heads=getattr(config, "num_key_value_heads", + config.num_attention_heads), + rope_theta=rope_theta, + rope_scaling=rope_scaling, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + bias=attention_bias, + cache_config=cache_config, + prefix=f"{prefix}.self_attn", + ) + self.mlp = LlamaMLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + 
hidden_act=config.hidden_act, + quant_config=quant_config, + bias=getattr(config, "mlp_bias", False), + prefix=f"{prefix}.mlp", + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + kv_cache: Optional[List[torch.Tensor]], + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class LlamaModel(nn.Module): + + def __init__( + self, + config: LlamaConfig, + attn_backend: AttentionBackend, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + if get_pp_group().is_first_rank or (config.tie_word_embeddings + and get_pp_group().is_last_rank): + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + quant_config=quant_config, + ) + else: + self.embed_tokens = PPMissingLayer() + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: LlamaDecoderLayer(config=config, + attn_backend=attn_backend, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix), + prefix=f"{prefix}.layers") + if get_pp_group().is_last_rank: + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + else: + self.norm = PPMissingLayer() + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - + self.start_layer] if kv_caches is not None else None, + attn_metadata, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + + 
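        # Last pipeline rank: the final RMSNorm fuses in the residual add and returns (hidden_states, residual).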
hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class LlamaForCausalLM(nn.Module, SupportsLoRA): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens", + "lm_head" + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } + embedding_padding_modules = ["lm_head"] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + # Mistral/Llama models can also be loaded with --load-format mistral + # from consolidated.safetensors checkpoints + mistral_mapping = { + "layers": "model.layers", + "attention": "self_attn", + "wq": "q_proj", + "wk": "k_proj", + "wv": "v_proj", + "wo": "o_proj", + "attention_norm": "input_layernorm", + "feed_forward": "mlp", + "w1": "gate_proj", + "w2": "down_proj", + "w3": "up_proj", + "ffn_norm": "post_attention_layernorm", + "tok_embeddings": "model.embed_tokens", + "output": "lm_head", + "norm": "model.norm" + } + + def __init__( + self, + config: LlamaConfig, + attn_backend: AttentionBackend, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = LlamaModel(config, + attn_backend, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model") + if get_pp_group().is_last_rank: + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + logit_scale) + self.sampler = Sampler() + else: + self.lm_head = PPMissingLayer() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + model_output = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return model_output + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + 
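            # Zero-filled buffers matching the hidden_states/residual tensors exchanged between pipeline stages.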
"hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + params_dict = dict(self.named_parameters()) + for name, loaded_weight in weights: + name, loaded_weight = self.maybe_remap_mistral(name, loaded_weight) + + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + # With tie_word_embeddings, we can skip lm_head.weight + # The weight might appear unnecessarily in the files if the model is + # processed with quantization, LoRA, fine-tuning, etc. + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # If this function is called, it should always initialize KV cache scale + # factors (or else raise an exception). 
Thus, handled exceptions should + # make sure to leave KV cache scale factors in a known good (dummy) state + def load_kv_cache_scales(self, quantization_param_path: str) -> None: + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + for layer_idx, scaling_factor in kv_cache_scales_loader( + quantization_param_path, tp_rank, tp_size, + self.config.num_hidden_layers, + self.config.__class__.model_type): + if not isinstance(self.model.layers[layer_idx], nn.Identity): + layer_self_attn = self.model.layers[layer_idx].self_attn + + if is_hip(): + # The scaling factor convention we are assuming is + # quantized_value * scaling_factor ~= true_value + # which is consistent with the practice of setting + # scaling_factor = tensor_amax / FPtype_max + scaling_factor *= 2 + if hasattr(layer_self_attn, "kv_scale"): + layer_self_attn.attn._kv_scale = scaling_factor + else: + raise RuntimeError("Self attention has no KV cache scaling " + "factor attribute!") + + # This function is used to remap the mistral format as + # used by Mistral and Llama <=2 + def maybe_remap_mistral( + self, name: str, + loaded_weight: torch.Tensor) -> Tuple[str, torch.Tensor]: + + def permute(w, n_heads): + attn_in = self.config.head_dim * n_heads + attn_out = self.config.hidden_size + + return w.view(n_heads, attn_in // n_heads // 2, 2, + attn_out).transpose(1, 2).reshape(attn_in, attn_out) + + mapping = self.mistral_mapping + modules = name.split(".") + + # rotary embeds should be sliced + if "wk" in modules: + loaded_weight = permute(loaded_weight, + self.config.num_key_value_heads) + elif "wq" in modules: + loaded_weight = permute(loaded_weight, + self.config.num_attention_heads) + + for item in modules: + if item in mapping and mapping[item] not in name: + name = name.replace(item, mapping[item]) + + return name, loaded_weight diff --git a/vllm/wde/decode_only/modelzoo/qwen2.py b/vllm/wde/decode_only/modelzoo/qwen2.py new file mode 100644 index 0000000000000..b67cfa3335767 --- /dev/null +++ b/vllm/wde/decode_only/modelzoo/qwen2.py @@ -0,0 +1,421 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
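+# Qwen2 follows the Llama-style decoder layout below: fused QKV projections (with bias), a SwiGLU MLP and pre-norm RMSNorm layers.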
+"""Inference-only Qwen2 model compatible with HuggingFace weights.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import Qwen2Config + +from vllm.config import CacheConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.utils import (is_pp_missing_parameter, + make_layers) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors +from vllm.wde.core.layers.attention import (Attention, AttentionBackend, + AttentionMetadata) + + +class Qwen2MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.gate_up_proj = MergedColumnParallelLinear( + hidden_size, [intermediate_size] * 2, + bias=False, + quant_config=quant_config) + self.down_proj = RowParallelLinear(intermediate_size, + hidden_size, + bias=False, + quant_config=quant_config) + if hidden_act != "silu": + raise ValueError(f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now.") + self.act_fn = SiluAndMul() + + def forward(self, x): + gate_up, _ = self.gate_up_proj(x) + x = self.act_fn(gate_up) + x, _ = self.down_proj(x) + return x + + +class Qwen2Attention(nn.Module): + + def __init__(self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + attn_backend: AttentionBackend, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + rope_scaling: Optional[Tuple] = None) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. 
+ assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim**-0.5 + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=True, + quant_config=quant_config, + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + ) + + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=self.rope_theta, + rope_scaling=rope_scaling, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + attn_backend=attn_backend) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + kv_cache: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, attn_metadata, kv_cache) + output, _ = self.o_proj(attn_output) + return output + + +class Qwen2DecoderLayer(nn.Module): + + def __init__( + self, + config: Qwen2Config, + attn_backend: AttentionBackend, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 1000000) + rope_scaling = getattr(config, "rope_scaling", None) + self.self_attn = Qwen2Attention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + attn_backend=attn_backend, + cache_config=cache_config, + quant_config=quant_config, + rope_scaling=rope_scaling) + self.mlp = Qwen2MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + quant_config=quant_config, + ) + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + residual: Optional[torch.Tensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # Self Attention + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + + # Fully Connected + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + + +class Qwen2Model(nn.Module): + + def __init__( + self, + config: Qwen2Config, + attn_backend: AttentionBackend, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = 
"", + ) -> None: + super().__init__() + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + ) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Qwen2DecoderLayer(config=config, + attn_backend=attn_backend, + cache_config=cache_config, + quant_config=quant_config), + prefix=f"{prefix}.layers", + ) + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: + return self.embed_tokens(input_ids) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - + self.start_layer] if kv_caches is not None else None, + attn_metadata, + residual, + ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) + return hidden_states + + +class Qwen2ForCausalLM(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + def __init__( + self, + config: Qwen2Config, + attn_backend: AttentionBackend, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.quant_config = quant_config + self.model = Qwen2Model(config, attn_backend, cache_config, + quant_config) + + if config.tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = ParallelLMHead(config.vocab_size, + config.hidden_size, + quant_config=quant_config) + + self.logits_processor = LogitsProcessor(config.vocab_size) + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + kv_caches: Optional[List[torch.Tensor]] = None, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + 
device=device), + }) + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.named_parameters(remove_duplicate=False)) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if self.config.tie_word_embeddings and "lm_head.weight" in name: + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/wde/decode_only/processor/__init__.py b/vllm/wde/decode_only/processor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/decode_only/processor/output_processor.py b/vllm/wde/decode_only/processor/output_processor.py new file mode 100644 index 0000000000000..97eecab09b25a --- /dev/null +++ b/vllm/wde/decode_only/processor/output_processor.py @@ -0,0 +1,37 @@ +from typing import List + +import torch + +from vllm.wde.core.llm_engine import LLMEngine +from vllm.wde.core.processor.output_processor import OutputProcessor +from vllm.wde.prefill_only.schema.engine_io import (PrefillOnlyRequestOutput, + PrefillOnlySchedulerOutput) + + +class DecodeOnlyHiddenStatesOutputProcessor(OutputProcessor): + + def __init__(self): + pass + + @classmethod + def from_engine(cls, engine: LLMEngine): + return cls() + + def __call__( + self, scheduler_output: PrefillOnlySchedulerOutput, + execute_output: torch.Tensor) -> List[PrefillOnlyRequestOutput]: + + request_outputs = [] + offset = 0 + for request in scheduler_output.requests: + prompt_token_ids = request.inputs.prompt_token_ids + n_tokens = len(prompt_token_ids) + request_outputs.append( + PrefillOnlyRequestOutput( + request_id=request.request_id, + prompt_token_ids=prompt_token_ids, + finished=True, + # last pooling + outputs=execute_output[offset + n_tokens - 1])) + offset += n_tokens + return request_outputs diff --git a/vllm/wde/decode_only/workflow.py b/vllm/wde/decode_only/workflow.py new file mode 100644 index 0000000000000..e1d9298a6479d --- /dev/null +++ b/vllm/wde/decode_only/workflow.py @@ -0,0 +1,25 @@ +from vllm.wde.core.workflow import Workflow +from vllm.wde.prefill_only.workflow import PrefillOnlyWorkflow + + +class DecodeOnlyWorkflow(Workflow): + EngineArgs: str = "vllm.wde.decode_only.arg_utils:DecodeOnlyEngineArgs" + attn_type: str = "DECODER" + + @classmethod + def 
from_engine(cls, engine): + if engine.engine_config.model_config.output_last_hidden_states: + workflow = PrefillOnlyWorkflow.from_engine(engine) + + if engine.engine_config.model_config.enable_bidirectional: + workflow.attn_type = "ENCODER" + else: + workflow.attn_type = "DECODER" + + workflow.OutputProcessor = ( + "vllm.wde.decode_only.processor." + "output_processor:" + "DecodeOnlyHiddenStatesOutputProcessor") + return workflow + else: + raise ValueError("Not supported") diff --git a/vllm/wde/encode_only/__init__.py b/vllm/wde/encode_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/encode_only/arg_utils.py b/vllm/wde/encode_only/arg_utils.py new file mode 100644 index 0000000000000..729505b54d1a2 --- /dev/null +++ b/vllm/wde/encode_only/arg_utils.py @@ -0,0 +1,108 @@ +from dataclasses import dataclass +from typing import List, Optional, Union + +from vllm.logger import init_logger +from vllm.wde.core.arg_utils import EngineArgs +from vllm.wde.core.config import (DeviceConfig, LoadConfig, + filter_unexpected_fields) +from vllm.wde.encode_only.config import (EncodeOnlyEngineConfig, ModelConfig, + PrefillOnlyParallelConfig, + PrefillOnlySchedulerConfig) + +logger = init_logger(__name__) + + +def nullable_str(val: str): + if not val or val == "None": + return None + return val + + +@filter_unexpected_fields +@dataclass +class EncodeOnlyEngineArgs(EngineArgs): + """Arguments for vLLM engine.""" + model: str + served_model_name: Optional[Union[List[str]]] = None + tokenizer: Optional[str] = None + skip_tokenizer_init: bool = False + tokenizer_mode: str = 'auto' + trust_remote_code: bool = False + download_dir: Optional[str] = None + load_format: str = 'auto' + dtype: str = 'auto' + kv_cache_dtype: str = 'auto' + quantization_param_path: Optional[str] = None + disable_sliding_window: bool = False + seed: int = 0 + + max_model_len: Optional[int] = None + max_num_batched_tokens: Optional[int] = None + max_num_seqs: int = 256 + max_num_on_the_fly: int = 3 + scheduling: str = "async" + + data_parallel_size: int = 0 + + disable_log_stats: bool = False + revision: Optional[str] = None + code_revision: Optional[str] = None + rope_scaling: Optional[dict] = None + rope_theta: Optional[float] = None + tokenizer_revision: Optional[str] = None + quantization: Optional[str] = None + disable_custom_all_reduce: bool = False + device: str = 'auto' + model_loader_extra_config: Optional[dict] = None + ignore_patterns: Optional[Union[str, List[str]]] = None + + def __post_init__(self): + if self.tokenizer is None: + self.tokenizer = self.model + + def create_engine_config(self) -> EncodeOnlyEngineConfig: + device_config = DeviceConfig(device=self.device) + model_config = ModelConfig( + model=self.model, + tokenizer=self.tokenizer, + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + dtype=self.dtype, + seed=self.seed, + revision=self.revision, + code_revision=self.code_revision, + rope_scaling=self.rope_scaling, + rope_theta=self.rope_theta, + tokenizer_revision=self.tokenizer_revision, + max_model_len=self.max_model_len, + quantization=self.quantization, + quantization_param_path=self.quantization_param_path, + disable_sliding_window=self.disable_sliding_window, + skip_tokenizer_init=self.skip_tokenizer_init, + served_model_name=self.served_model_name) + + scheduler_config = PrefillOnlySchedulerConfig( + max_num_batched_tokens=self.max_num_batched_tokens, + max_num_seqs=self.max_num_seqs, + max_model_len=model_config.max_model_len, + 
max_num_on_the_fly=self.max_num_on_the_fly, + scheduling=self.scheduling) + + load_config = LoadConfig( + load_format=self.load_format, + download_dir=self.download_dir, + model_loader_extra_config=self.model_loader_extra_config, + ignore_patterns=self.ignore_patterns, + ) + + if self.data_parallel_size > 0: + parallel_config = PrefillOnlyParallelConfig( + data_parallel_size=self.data_parallel_size) + else: + parallel_config = None + + return EncodeOnlyEngineConfig(model_config=model_config, + scheduler_config=scheduler_config, + device_config=device_config, + load_config=load_config, + parallel_config=parallel_config) diff --git a/vllm/wde/encode_only/config.py b/vllm/wde/encode_only/config.py new file mode 100644 index 0000000000000..eb11b2f53fb1d --- /dev/null +++ b/vllm/wde/encode_only/config.py @@ -0,0 +1,46 @@ +from dataclasses import dataclass, fields +from typing import Optional + +from vllm.logger import init_logger +from vllm.wde.core.config import EngineConfig, ModelConfig +from vllm.wde.prefill_only.config import (PrefillOnlyParallelConfig, + PrefillOnlySchedulerConfig) + +logger = init_logger(__name__) + +_GB = 1 << 30 + + +@dataclass(frozen=True) +class EncodeOnlyEngineConfig(EngineConfig): + model_config: ModelConfig + scheduler_config: PrefillOnlySchedulerConfig + parallel_config: Optional[PrefillOnlyParallelConfig] + + def to_dict(self): + """Return the configs as a dictionary, for use in **kwargs. + """ + return dict( + (field.name, getattr(self, field.name)) for field in fields(self)) + + def log_config(self): + from vllm.version import __version__ as VLLM_VERSION + logger.info( + "Initializing an Encode Only engine (v%s) with config: " + "model=%r, tokenizer=%r, " + "tokenizer_mode=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, " + "device_config=%s, served_model_name=%s, " + "max_num_on_the_fly=%d, scheduling=%s)", VLLM_VERSION, + self.model_config.model, self.model_config.tokenizer, + self.model_config.tokenizer_mode, + self.model_config.trust_remote_code, self.model_config.dtype, + self.model_config.max_model_len, self.load_config.download_dir, + self.load_config.load_format, self.device_config.device, + self.model_config.served_model_name, + self.scheduler_config.max_num_on_the_fly, + self.scheduler_config.scheduling) + if self.parallel_config is not None: + logger.info("Parallel config: data_parallel_size=%d", + self.parallel_config.data_parallel_size) diff --git a/vllm/wde/encode_only/modelzoo/__init__.py b/vllm/wde/encode_only/modelzoo/__init__.py new file mode 100644 index 0000000000000..714b53bd306f3 --- /dev/null +++ b/vllm/wde/encode_only/modelzoo/__init__.py @@ -0,0 +1,9 @@ +TASK = "encode_only" +WORKFLOW = "vllm.wde.encode_only.workflow:EncodeOnlyWorkflow" + +# Architecture -> (task, module, class, workflow). +ENCODE_ONLY_MODELS = { + "XLMRobertaForMaskedLM": + (TASK, "xlm_roberta", "XLMRobertaForMaskedLM", WORKFLOW), + "BertForMaskedLM": (TASK, "bert", "BertForMaskedLM", WORKFLOW), +} diff --git a/vllm/wde/encode_only/modelzoo/bert.py b/vllm/wde/encode_only/modelzoo/bert.py new file mode 100644 index 0000000000000..ca57f0833cd17 --- /dev/null +++ b/vllm/wde/encode_only/modelzoo/bert.py @@ -0,0 +1,404 @@ +# Derived from Bert implementation posted on HuggingFace; license below: +# coding=utf-8 +# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team. # noqa: E501 +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""PyTorch BERT model.""" + +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import BertConfig +from transformers.utils import logging + +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.utils import is_pp_missing_parameter +from vllm.wde.core.layers.attention import (Attention, AttentionBackend, + AttentionMetadata) + +logger = logging.get_logger(__name__) + + +class LoadWeightsMixin: + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "query", "q"), + ("qkv_proj", "key", "k"), + ("qkv_proj", "value", "v") + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + + for name, loaded_weight in weights: + if hasattr(self, "prefix"): + name = self.prefix + name + + if name in self._ignore_weights_keys: + continue + + if name == "bert.embeddings.token_type_embeddings.weight": + # token_type_ids is all zero, + # so we only need token_type_embeddings[0] + self.bert.embeddings.init_token_type_embeddings0() + default_weight_loader( + self.bert.embeddings.token_type_embeddings0, + loaded_weight[0]) + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # https://huggingface.co/google-bert/bert-base-uncased/discussions/70 + # https://github.com/huggingface/transformers/blob/fee86516a48c92133847fc7b44ca2f83c7c5634d/src/transformers/modeling_utils.py#L691-L720 + if "LayerNorm.gamma" in name: + name = name.replace("LayerNorm.gamma", "LayerNorm.weight") + if "LayerNorm.beta" in name: + name = name.replace("LayerNorm.beta", "LayerNorm.bias") + + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. 
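+                # maybe_remap_kv_scale_name returns None when the scale has no matching parameter, in which case the weight is skipped below.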
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + +class BertEmbeddings(nn.Module): + + def __init__(self, config: BertConfig): + super().__init__() + self.config = config + self.position_embedding_type = getattr(config, + "position_embedding_type", + "absolute") + assert self.position_embedding_type == "absolute" + + self.word_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.token_type_embeddings0 = None + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, + config.hidden_size, + padding_idx=config.pad_token_id) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def init_token_type_embeddings0(self): + del self.token_type_embeddings0 + self.register_buffer( + "token_type_embeddings0", + torch.zeros(self.config.hidden_size, + dtype=self.word_embeddings.weight.dtype, + device=self.word_embeddings.weight.device)) + + def forward(self, input_ids, positions): + embeddings = self.word_embeddings(input_ids) + if self.token_type_embeddings0 is not None: + token_type_embeddings = self.token_type_embeddings0 + embeddings += token_type_embeddings + + embeddings += self.position_embeddings(positions) + embeddings = self.LayerNorm(embeddings) + return embeddings + + +class BertSelfAttention(nn.Module): + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + hidden_size = config.hidden_size + num_heads = config.num_attention_heads + num_kv_heads = config.num_attention_heads + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = hidden_size // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + num_heads, + num_kv_heads, + bias=True, + quant_config=quant_config, + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + quant_config=quant_config, + attn_backend=attn_backend) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + attn_output = self.attn(q, k, v, attn_metadata) + return attn_output + + +class BertSelfOutput(nn.Module): + + def __init__(self, + config: BertConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense = ColumnParallelLinear(config.hidden_size, + config.hidden_size, + quant_config=quant_config) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertAttention(nn.Module): + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.self = BertSelfAttention(config, attn_backend) + self.output = 
BertSelfOutput(config, quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + self_outputs = self.self(hidden_states, attn_metadata) + attention_output = self.output(self_outputs, hidden_states) + return attention_output + + +class BertIntermediate(nn.Module): + + def __init__(self, + config: BertConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense = RowParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config) + self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class BertOutput(nn.Module): + + def __init__(self, + config: BertConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class BertLayer(nn.Module): + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.attention = BertAttention(config, attn_backend, quant_config) + self.intermediate = BertIntermediate(config, quant_config) + self.output = BertOutput(config, quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + attention_output = self.attention(hidden_states, attn_metadata) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class BertEncoder(nn.Module): + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.layer = nn.ModuleList([ + BertLayer(config, attn_backend, quant_config) + for _ in range(config.num_hidden_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module(hidden_states, attn_metadata) + return hidden_states + + +class BertPooler(nn.Module): + + def __init__(self, + config: BertConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense = RowParallelLinear(config.hidden_size, + config.hidden_size, + bias=True, + quant_config=quant_config) + self.activation = nn.Tanh() + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + seq_start_loc = attn_metadata.seq_start_loc + first_token_tensor = hidden_states[seq_start_loc[:-1]] + pooled_output, _ = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +class BertModel(nn.Module): + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + add_pooling_layer: bool = True, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config 
= config + self.embeddings = BertEmbeddings(config) + self.encoder = BertEncoder(config, attn_backend, quant_config) + self.pooler = BertPooler(config) if add_pooling_layer else None + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor]: + embedding_output = self.embeddings( + input_ids=input_ids, + positions=positions, + ) + sequence_output = self.encoder(embedding_output, attn_metadata) + pooled_output = self.pooler( + sequence_output, + attn_metadata) if self.pooler is not None else None + return sequence_output, pooled_output + + +class BertForMaskedLM(nn.Module, LoadWeightsMixin): + _ignore_weights_keys = [ + "cls.predictions.transform.LayerNorm.gamma", + "cls.predictions.transform.dense.weight", + "cls.seq_relationship.weight", + ] + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None, + *args, + **kwargs): + super().__init__() + self.config = config + self.quant_config = quant_config + + self.bert = BertModel(config, + attn_backend, + quant_config=quant_config, + add_pooling_layer=True) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> Tuple[torch.Tensor]: + + sequence_output, pooled_output = self.bert( + input_ids, + positions, + attn_metadata, + ) + + return sequence_output, pooled_output diff --git a/vllm/wde/encode_only/modelzoo/xlm_roberta.py b/vllm/wde/encode_only/modelzoo/xlm_roberta.py new file mode 100644 index 0000000000000..f64fb9904e784 --- /dev/null +++ b/vllm/wde/encode_only/modelzoo/xlm_roberta.py @@ -0,0 +1,469 @@ +# Derived from XLM-RoBERTa implementation posted on HuggingFace; license below: +# coding=utf-8 +# Copyright 2019 Facebook AI Research and the HuggingFace Inc. team. +# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
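+# Encode-only XLM-RoBERTa: absolute position embeddings, fused QKV projections, and an MLM head whose decoder weight is tied to the word embeddings.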
+"""PyTorch XLM-RoBERTa model.""" + +from typing import Iterable, Optional, Tuple + +import torch +from torch import nn +from transformers import XLMRobertaConfig +from transformers.utils import logging + +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.utils import is_pp_missing_parameter +from vllm.wde.core.layers.attention import (Attention, AttentionBackend, + AttentionMetadata) + +logger = logging.get_logger(__name__) + + +class LoadWeightsMixin: + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "query", "q"), + ("qkv_proj", "key", "k"), + ("qkv_proj", "value", "v") + ] + + params_dict = dict(self.named_parameters(remove_duplicate=False)) + + for name, loaded_weight in weights: + if hasattr(self, "prefix"): + name = self.prefix + name + + if name in self._ignore_weights_keys: + continue + + if name == "roberta.embeddings.token_type_embeddings.weight": + # token_type_ids is all zero, + # so we only need token_type_embeddings[0] + self.roberta.embeddings.init_token_type_embeddings0() + default_weight_loader( + self.roberta.embeddings.token_type_embeddings0, + loaded_weight[0]) + continue + + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + + name = name.replace(weight_name, param_name) + + if name.endswith(".bias") and name not in params_dict: + continue + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. 
+ name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + if is_pp_missing_parameter(name, self): + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + if hasattr(self, "tie_weights"): + self.tie_weights() + + +class XLMRobertaEmbeddings(nn.Module): + + def __init__(self, config: XLMRobertaConfig): + super().__init__() + self.config = config + self.position_embedding_type = getattr(config, + "position_embedding_type", + "absolute") + assert self.position_embedding_type == "absolute" + + self.word_embeddings = nn.Embedding(config.vocab_size, + config.hidden_size, + padding_idx=config.pad_token_id) + self.token_type_embeddings0 = None + self.position_embeddings = nn.Embedding( + config.max_position_embeddings, + config.hidden_size, + padding_idx=config.pad_token_id) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def init_token_type_embeddings0(self): + del self.token_type_embeddings0 + self.register_buffer( + "token_type_embeddings0", + torch.zeros(self.config.hidden_size, + dtype=self.word_embeddings.weight.dtype, + device=self.word_embeddings.weight.device)) + + def forward(self, input_ids, positions): + embeddings = self.word_embeddings(input_ids) + + # token_type_embeddings is all zero in FacebookAI/xlm-roberta, + # so we don't need it. + # token_type_ids is all zero in BGEM3, + # so we only need token_type_embeddings[0] + if self.token_type_embeddings0 is not None: + token_type_embeddings = self.token_type_embeddings0 + embeddings += token_type_embeddings + + embeddings += self.position_embeddings(positions) + embeddings = self.LayerNorm(embeddings) + return embeddings + + +class XLMRobertaSelfAttention(nn.Module): + + def __init__(self, + config: XLMRobertaConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + hidden_size = config.hidden_size + num_heads = config.num_attention_heads + num_kv_heads = config.num_attention_heads + + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = hidden_size // num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + + self.scaling = self.head_dim**-0.5 + + self.qkv_proj = QKVParallelLinear( + config.hidden_size, + self.head_dim, + num_heads, + num_kv_heads, + bias=True, + quant_config=quant_config, + ) + + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + quant_config=quant_config, + attn_backend=attn_backend) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + attn_output = self.attn(q, k, v, attn_metadata) + return attn_output + + +class XLMRobertaSelfOutput(nn.Module): + + def __init__(self, + config: XLMRobertaConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense = ColumnParallelLinear(config.hidden_size, + config.hidden_size, + quant_config=quant_config) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + 
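+# Self-attention plus the residual output projection, mirroring HF's XLMRobertaAttention on top of vLLM's parallel linear layers.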
+class XLMRobertaAttention(nn.Module): + + def __init__(self, + config: XLMRobertaConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.self = XLMRobertaSelfAttention(config, attn_backend) + self.output = XLMRobertaSelfOutput(config, quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + self_outputs = self.self(hidden_states, attn_metadata) + attention_output = self.output(self_outputs, hidden_states) + return attention_output + + +class XLMRobertaIntermediate(nn.Module): + + def __init__(self, + config: XLMRobertaConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense = RowParallelLinear(config.hidden_size, + config.intermediate_size, + bias=True, + quant_config=quant_config) + self.intermediate_act_fn = get_act_fn(config.hidden_act) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.intermediate_act_fn(hidden_states) + return hidden_states + + +class XLMRobertaOutput(nn.Module): + + def __init__(self, + config: XLMRobertaConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense = RowParallelLinear(config.intermediate_size, + config.hidden_size, + bias=True, + quant_config=quant_config) + self.LayerNorm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + def forward(self, hidden_states: torch.Tensor, + input_tensor: torch.Tensor) -> torch.Tensor: + hidden_states, _ = self.dense(hidden_states) + hidden_states = self.LayerNorm(hidden_states + input_tensor) + return hidden_states + + +class XLMRobertaLayer(nn.Module): + + def __init__(self, + config: XLMRobertaConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.attention = XLMRobertaAttention(config, attn_backend, + quant_config) + self.intermediate = XLMRobertaIntermediate(config, quant_config) + self.output = XLMRobertaOutput(config, quant_config) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + attention_output = self.attention(hidden_states, attn_metadata) + intermediate_output = self.intermediate(attention_output) + layer_output = self.output(intermediate_output, attention_output) + return layer_output + + +class XLMRobertaEncoder(nn.Module): + + def __init__(self, + config: XLMRobertaConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.layer = nn.ModuleList([ + XLMRobertaLayer(config, attn_backend, quant_config) + for _ in range(config.num_hidden_layers) + ]) + + def forward( + self, + hidden_states: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + for i, layer_module in enumerate(self.layer): + hidden_states = layer_module(hidden_states, attn_metadata) + return hidden_states + + +class XLMRobertaModel(nn.Module): + + def __init__(self, + config: XLMRobertaConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.config = config + self.embeddings = XLMRobertaEmbeddings(config) + self.encoder = XLMRobertaEncoder(config, attn_backend, quant_config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + positions += self.config.pad_token_id + 1 
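+        # HF (XLM-)RoBERTa position embeddings are offset: real tokens start at pad_token_id + 1, so shift the 0-based positions to match the checkpoint.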
+ + embedding_output = self.embeddings( + input_ids=input_ids, + positions=positions, + ) + + encoder_outputs = self.encoder(embedding_output, attn_metadata) + + return encoder_outputs + + +class XLMRobertaLMHead(nn.Module): + + def __init__(self, + config: XLMRobertaConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense = ColumnParallelLinear(config.hidden_size, + config.hidden_size, + quant_config=quant_config) + self.layer_norm = nn.LayerNorm(config.hidden_size, + eps=config.layer_norm_eps) + + self.decoder = ColumnParallelLinear(config.hidden_size, + config.vocab_size, + quant_config=quant_config) + self.gelu = get_act_fn("gelu") + + def forward(self, features): + x, _ = self.dense(features) + x = self.gelu(x) + x = self.layer_norm(x) + x, _ = self.decoder(x) + return x + + +class XLMRobertaForMaskedLM(nn.Module, LoadWeightsMixin): + _ignore_weights_keys = [ + "roberta.pooler.dense.weight", + "roberta.pooler.dense.bias", + # token_type_embeddings is all zero + "roberta.embeddings.token_type_embeddings.weight" + ] + + def __init__(self, + config: XLMRobertaConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None, + *args, + **kwargs): + super().__init__() + self.config = config + self.quant_config = quant_config + + self.roberta = XLMRobertaModel(config, attn_backend, quant_config) + self.lm_head = XLMRobertaLMHead(config, quant_config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + sequence_output = self.roberta( + input_ids, + positions, + attn_metadata, + ) + logits = self.lm_head(sequence_output) + return logits + + def tie_weights(self): + self.lm_head.decoder.weight = ( + self.roberta.embeddings.word_embeddings.weight) + self.lm_head.decoder.bias.zero_() + + +class XLMRobertaClassificationHead(nn.Module): + + def __init__(self, + config: XLMRobertaConfig, + quant_config: Optional[QuantizationConfig] = None): + super().__init__() + self.dense = ColumnParallelLinear(config.hidden_size, + config.hidden_size, + quant_config=quant_config) + self.out_proj = ColumnParallelLinear(config.hidden_size, + config.num_labels, + quant_config=quant_config) + + def forward(self, features: torch.Tensor) -> torch.Tensor: + x, _ = self.dense(features) + x = torch.tanh(x) + x, _ = self.out_proj(x) + return x + + +class XLMRobertaForSequenceClassification(nn.Module, LoadWeightsMixin): + _ignore_weights_keys = [ + "roberta.pooler.dense.weight", "roberta.pooler.dense.bias" + ] + + def __init__(self, + config: XLMRobertaConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None, + *args, + **kwargs): + super().__init__() + self.config = config + self.quant_config = quant_config + self.num_labels = config.num_labels + + self.roberta = XLMRobertaModel(config, attn_backend, quant_config) + self.classifier = XLMRobertaClassificationHead(config) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + + sequence_output = self.roberta( + input_ids, + positions, + attn_metadata, + ) + + seq_start_loc = attn_metadata.seq_start_loc + + # take token (equiv. 
to [CLS]) + cls_features = sequence_output[seq_start_loc[:-1]] + + logits = self.classifier(cls_features) + return logits diff --git a/vllm/wde/encode_only/workflow.py b/vllm/wde/encode_only/workflow.py new file mode 100644 index 0000000000000..e1abba4600c1e --- /dev/null +++ b/vllm/wde/encode_only/workflow.py @@ -0,0 +1,6 @@ +from vllm.wde.prefill_only.workflow import PrefillOnlyWorkflow + + +class EncodeOnlyWorkflow(PrefillOnlyWorkflow): + EngineArgs: str = "vllm.wde.encode_only.arg_utils:EncodeOnlyEngineArgs" + attn_type: str = "ENCODER" diff --git a/vllm/wde/entrypoints/__init__.py b/vllm/wde/entrypoints/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/entrypoints/llm.py b/vllm/wde/entrypoints/llm.py new file mode 100644 index 0000000000000..6c2874f36ceb1 --- /dev/null +++ b/vllm/wde/entrypoints/llm.py @@ -0,0 +1,180 @@ +from typing import List, Optional, Sequence, Union, cast + +from tqdm import tqdm +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast + +from vllm.logger import init_logger +from vllm.utils import Counter +from vllm.wde.core.inputs.tokenizer import get_cached_tokenizer +from vllm.wde.core.llm_engine import LLMEngine +from vllm.wde.core.schema.engine_io import Params, RequestOutput +from vllm.wde.core.schema.engine_io import TextOnlyInputs as PromptInputs +from vllm.wde.reranker.schema.engine_io import RerankerInputs + +logger = init_logger(__name__) + + +class LLM: + + def __init__( + self, + model: str, + tokenizer: Optional[str] = None, + tokenizer_mode: str = "auto", + skip_tokenizer_init: bool = False, + trust_remote_code: bool = False, + tensor_parallel_size: int = 1, + dtype: str = "auto", + quantization: Optional[str] = None, + revision: Optional[str] = None, + tokenizer_revision: Optional[str] = None, + seed: int = 0, + gpu_memory_utilization: float = 0.9, + swap_space: int = 4, + cpu_offload_gb: float = 0, + enforce_eager: bool = False, + max_context_len_to_capture: Optional[int] = None, + max_seq_len_to_capture: int = 8192, + disable_custom_all_reduce: bool = False, + **kwargs, + ) -> None: + if "disable_log_stats" not in kwargs: + kwargs["disable_log_stats"] = True + removed_vision_keys = ("image_token_id", "image_feature_size", + "image_input_shape", "image_input_type") + if any(k in kwargs for k in removed_vision_keys): + raise TypeError( + "There is no need to pass vision-related arguments anymore.") + engine_args = dict( + model=model, + tokenizer=tokenizer, + tokenizer_mode=tokenizer_mode, + skip_tokenizer_init=skip_tokenizer_init, + trust_remote_code=trust_remote_code, + tensor_parallel_size=tensor_parallel_size, + dtype=dtype, + quantization=quantization, + revision=revision, + tokenizer_revision=tokenizer_revision, + seed=seed, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=swap_space, + cpu_offload_gb=cpu_offload_gb, + enforce_eager=enforce_eager, + max_context_len_to_capture=max_context_len_to_capture, + max_seq_len_to_capture=max_seq_len_to_capture, + disable_custom_all_reduce=disable_custom_all_reduce, + **kwargs, + ) + self.llm_engine = LLMEngine.from_engine_args(engine_args) + self.request_counter = Counter() + + def get_tokenizer( + self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + return self.llm_engine.tokenizer.tokenizer + + def set_tokenizer( + self, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + ) -> None: + # While CachedTokenizer is dynamic, have no choice but + # compare class name. 
Misjudgment will arise from + # user-defined tokenizer started with 'Cached' + if tokenizer.__class__.__name__.startswith("Cached"): + self.llm_engine.tokenizer.tokenizer = tokenizer + else: + self.llm_engine.tokenizer.tokenizer = get_cached_tokenizer( + tokenizer) + + def encode( + self, + inputs: Union[Union[PromptInputs, Sequence[PromptInputs]], + Optional[Union[str, List[str]]]] = None, + pooling_params: Optional[Union[Params, Sequence[Params]]] = None, + use_tqdm: bool = True, + ) -> List[RequestOutput]: + inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], inputs) + + if pooling_params is None: + # Use default pooling params. + pooling_params = Params() + + self._validate_and_add_requests( + inputs=inputs, + params=pooling_params, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, RequestOutput) + + def reranker( + self, + inputs: RerankerInputs, + params: Optional[Union[Params, Sequence[Params]]] = None, + use_tqdm: bool = True, + ) -> List[RequestOutput]: + inputs = cast(Union[PromptInputs, Sequence[PromptInputs]], inputs) + + for i, request_inputs in enumerate(inputs): + self._add_request( + request_inputs, + params[i] if isinstance(params, Sequence) else params) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return LLMEngine.validate_outputs(outputs, RequestOutput) + + def _validate_and_add_requests( + self, + inputs: Union[PromptInputs, Sequence[PromptInputs]], + params: Optional[Union[Params, Sequence[Params]]] = None, + ) -> None: + if isinstance(inputs, (str, dict)): + # Convert a single prompt to a list. + inputs = [inputs] + + num_requests = len(inputs) + + if isinstance(params, list) and len(params) != num_requests: + raise ValueError("The lengths of prompts and params " + "must be the same.") + + # Add requests to the engine. + for i, request_inputs in enumerate(inputs): + self._add_request( + request_inputs, + params[i] if isinstance(params, Sequence) else params) + + def _add_request( + self, + inputs: PromptInputs, + params: Params, + ) -> None: + request_id = str(next(self.request_counter)) + self.llm_engine.add_request(request_id, inputs, params) + + def _run_engine(self, *, use_tqdm: bool) -> List[RequestOutput]: + # Initialize tqdm. + if use_tqdm: + num_requests = self.llm_engine.get_num_unfinished_requests() + pbar = tqdm( + total=num_requests, + desc="Processed prompts", + dynamic_ncols=True, + postfix=(f"est. speed input: {0:.2f} toks/s, " + f"output: {0:.2f} toks/s"), + ) + # Run the engine. + outputs: List[RequestOutput] = [] + while self.llm_engine.has_unfinished_requests(): + step_outputs = self.llm_engine.step() + for output in step_outputs: + if output.finished: + outputs.append(output) + if use_tqdm: + pbar.update(1) + if use_tqdm: + pbar.close() + # Sort the outputs by request ID. + # This is necessary because some requests may be finished earlier than + # its previous requests. 
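The `return` just below sorts by `int(x.request_id)` rather than by the raw string because request ids come from an integer counter, and a lexicographic sort would misorder them once ids reach two digits. A quick standalone illustration:

```python
request_ids = [str(i) for i in range(12)]

print(sorted(request_ids)[:5])           # ['0', '1', '10', '11', '2'] - lexicographic, wrong order
print(sorted(request_ids, key=int)[:5])  # ['0', '1', '2', '3', '4']  - arrival order
```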
+ return sorted(outputs, key=lambda x: int(x.request_id)) diff --git a/vllm/wde/prefill_only/__init__.py b/vllm/wde/prefill_only/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/prefill_only/config.py b/vllm/wde/prefill_only/config.py new file mode 100644 index 0000000000000..1cebc17671dc0 --- /dev/null +++ b/vllm/wde/prefill_only/config.py @@ -0,0 +1,68 @@ +from typing import Optional + +from vllm.logger import init_logger +from vllm.wde.core.config import ParallelConfig, SchedulerConfig + +logger = init_logger(__name__) + +_GB = 1 << 30 + + +class PrefillOnlySchedulerConfig(SchedulerConfig): + + def __init__(self, + max_model_len: int, + max_num_batched_tokens: Optional[int] = None, + max_num_requests: Optional[int] = None, + max_num_seqs: Optional[int] = None, + max_num_on_the_fly: Optional[int] = 3, + scheduling: str = "sync") -> None: + self.max_model_len = max_model_len + self.max_num_requests: int = 0 + self.max_num_batched_tokens: int = 0 + self.max_num_on_the_fly: int = max_num_on_the_fly + self.scheduling = scheduling + + self.set_args(max_num_batched_tokens, max_num_requests, max_num_seqs) + + def set_args(self, + max_num_batched_tokens: Optional[int] = None, + max_num_requests: Optional[int] = None, + max_num_seqs: Optional[int] = None): + if max_num_seqs is not None: + self.max_num_requests = max_num_seqs + else: + self.max_num_requests = max_num_requests + + if max_num_batched_tokens is not None: + self.max_num_batched_tokens = max_num_batched_tokens + else: + self.max_num_batched_tokens = (self.max_model_len * + self.max_num_requests) + + self._verify_args() + + def _verify_args(self) -> None: + if self.max_num_batched_tokens < self.max_model_len: + raise ValueError( + f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " + "be greater than or equal to max_model_len " + f"({self.max_model_len}).") + + if self.max_num_on_the_fly < 2: + raise ValueError( + f"max_num_on_the_fly {self.max_num_on_the_fly} must " + "be greater than 1") + + if self.scheduling not in ["sync", "async", "double_buffer"]: + raise ValueError(f"scheduling {self.scheduling} must " + f"in sync, async double_buffer") + + +class PrefillOnlyParallelConfig(ParallelConfig): + + def __init__( + self, + data_parallel_size: int, + ): + self.data_parallel_size = data_parallel_size diff --git a/vllm/wde/prefill_only/executor/__init__.py b/vllm/wde/prefill_only/executor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/prefill_only/executor/gpu_data_parallelism_executor.py b/vllm/wde/prefill_only/executor/gpu_data_parallelism_executor.py new file mode 100644 index 0000000000000..aaed487a5b395 --- /dev/null +++ b/vllm/wde/prefill_only/executor/gpu_data_parallelism_executor.py @@ -0,0 +1,80 @@ +import atexit +from queue import Queue +from threading import Thread +from typing import List, Optional + +from vllm.logger import init_logger +from vllm.wde.core.config import EngineConfig +from vllm.wde.core.layers.attention import AttentionBackend +from vllm.wde.core.llm_engine import LLMEngine +from vllm.wde.core.worker import WorkerBase, create_worker +from vllm.wde.core.workflow import Workflow +from vllm.wde.prefill_only.executor.gpu_executor import ( + double_buffer_execute_loop, simple_execute_loop) + +logger = init_logger(__name__) + + +class GPUDataParallelismExecutor: + support_scheduling = ["async_scheduling"] + + def __init__(self, engine_config: EngineConfig, workflow: Workflow, + attn_backend: AttentionBackend, executor_in: 
Queue, + executor_out: Queue) -> None: + self.engine_config = engine_config + self.workflow = workflow + self.attn_backend = attn_backend + self.output_to_cpu = False + + self.executor_in = executor_in + self.executor_out = executor_out + + self.workers: Optional[List[WorkerBase]] = None + + @classmethod + def from_engine(cls, engine: LLMEngine): + return cls(engine_config=engine.engine_config, + workflow=engine.workflow, + attn_backend=engine.attn_backend, + executor_in=engine.executor_in, + executor_out=engine.executor_out) + + def thread_target(self, rank: int): + # Is there a better way to avoid loading the model multiple times? + # Load to cpu first? + worker_kwargs = dict(engine_config=self.engine_config, + attn_backend=self.attn_backend, + envs={'CUDA_VISIBLE_DEVICES': str(rank)}) + worker_kwargs.update(module=self.workflow.Worker) + worker = create_worker(**worker_kwargs) + worker.init_device() + worker.load_model() + + if self.engine_config.scheduler_config.scheduling == "double_buffer": + execute_loop = double_buffer_execute_loop + else: + execute_loop = simple_execute_loop + + execute_loop(worker, self.executor_in, self.executor_out, + self.output_to_cpu) + + def ensure_start_execute_loop(self): + if self.workers is None: + self.workers = [] + for rank in range( + self.engine_config.parallel_config.data_parallel_size): + worker = Thread(target=self.thread_target, + args=(rank, ), + daemon=True) + worker.start() + self.workers.append(worker) + atexit.register(self.shutdown_execute_loop) + + def shutdown_execute_loop(self): + if self.workers is not None: + for worker in self.workers: + self.executor_in.put(None) + for worker in self.workers: + worker.join() + self.workers = None + atexit.unregister(self.shutdown_execute_loop) diff --git a/vllm/wde/prefill_only/executor/gpu_executor.py b/vllm/wde/prefill_only/executor/gpu_executor.py new file mode 100644 index 0000000000000..55f55a9d3703d --- /dev/null +++ b/vllm/wde/prefill_only/executor/gpu_executor.py @@ -0,0 +1,208 @@ +import atexit +import queue +from queue import Queue +from threading import Thread +from typing import Optional + +import torch + +from vllm.logger import init_logger +from vllm.wde.core.config import EngineConfig +from vllm.wde.core.layers.attention import AttentionBackend +from vllm.wde.core.llm_engine import LLMEngine +from vllm.wde.core.schema.execute_io import ExecuteInput, ExecuteOutput +from vllm.wde.core.worker import WorkerBase, create_worker +from vllm.wde.core.workflow import Workflow + +logger = init_logger(__name__) + + +class GPUExecutor: + support_scheduling = ["sync_scheduling"] + + def __init__( + self, + engine_config: EngineConfig, + workflow: Workflow, + attn_backend: AttentionBackend, + ) -> None: + self.engine_config = engine_config + self.workflow = workflow + self.attn_backend = attn_backend + self.output_to_cpu = False + self._init_executor() + + @classmethod + def from_engine(cls, engine: LLMEngine): + return cls(engine_config=engine.engine_config, + workflow=engine.workflow, + attn_backend=engine.attn_backend) + + def _init_executor(self) -> None: + """Initialize the worker and load the model. 
+ """ + + worker_kwargs = dict( + engine_config=self.engine_config, + attn_backend=self.attn_backend, + ) + worker_kwargs.update(module=self.workflow.Worker) + + self.worker = create_worker(**worker_kwargs) + self.worker.init_device() + self.worker.load_model() + + def execute_model(self, + executor_input: ExecuteInput) -> Optional[ExecuteOutput]: + executor_input.model_input.to(self.worker.device) + output = self.worker(executor_input) + if self.output_to_cpu: + output.to("cpu") + return output + + def shutdown_execute_loop(self): + pass + + +class GPUAsyncExecutor(GPUExecutor): + support_scheduling = ["async_scheduling"] + + def __init__(self, engine_config: EngineConfig, workflow: Workflow, + attn_backend: AttentionBackend, executor_in: Queue, + executor_out: Queue) -> None: + super().__init__(engine_config, workflow, attn_backend) + self.executor_in = executor_in + self.executor_out = executor_out + + self.executor_thread: Optional[Thread] = None + + if self.engine_config.scheduler_config.scheduling == "double_buffer": + self.execute_loop = double_buffer_execute_loop + else: + self.execute_loop = simple_execute_loop + + @classmethod + def from_engine(cls, engine: LLMEngine): + return cls(engine_config=engine.engine_config, + workflow=engine.workflow, + attn_backend=engine.attn_backend, + executor_in=engine.executor_in, + executor_out=engine.executor_out) + + def ensure_start_execute_loop(self): + if self.executor_thread is None or not self.executor_thread.is_alive(): + self.executor_thread = Thread(target=self.execute_loop, + args=(self.worker, self.executor_in, + self.executor_out, + self.output_to_cpu), + daemon=True) + self.executor_thread.start() + atexit.register(self.shutdown_execute_loop) + + def shutdown_execute_loop(self): + if self.executor_thread.is_alive(): + self.executor_in.put(None) + self.executor_thread.join() + atexit.unregister(self.shutdown_execute_loop) + + +def simple_execute_loop(worker: WorkerBase, + executor_in: Queue, + executor_out: Queue, + output_to_cpu: bool = False): + + def execute_model(executor_input: ExecuteInput) -> Optional[ExecuteOutput]: + executor_input.model_input.to(worker.device) + output = worker(executor_input) + if output_to_cpu: + output.to("cpu") + return output + + while True: + o = executor_in.get() + if o is None: + break + + scheduler_output, executor_input = o + executor_output = execute_model(executor_input) + if output_to_cpu: + executor_output.to("cpu") + executor_out.put((scheduler_output, executor_output)) + + +def double_buffer_execute_loop(worker: WorkerBase, + executor_in: Queue, + executor_out: Queue, + output_to_cpu: bool = False): + from dataclasses import dataclass + + from vllm.wde.core.schema.engine_io import SchedulerOutput + + @dataclass + class Task: + scheduler_output: SchedulerOutput + executor_input: ExecuteInput + executor_output: Optional[ExecuteOutput] + + @classmethod + def get(cls, block): + o = executor_in.get(block) + if o is None: + return None + + scheduler_output, executor_input = o + + task = cls(scheduler_output=scheduler_output, + executor_input=executor_input, + executor_output=None) + return task + + current_task: Optional[Task] = None + next_task: Optional[Task] = None + compute_stream = torch.cuda.Stream() + io_stream = torch.cuda.Stream() + + go_on = True + while go_on: + if current_task is None: + current_task = Task.get(block=True) + if current_task is None: + break + + with torch.cuda.stream(compute_stream): + current_task.executor_input.model_input.to(worker.device, + non_blocking=True) + 
current_task.executor_output = worker( + current_task.executor_input) + end_compute = torch.cuda.Event() + else: + with torch.cuda.stream(compute_stream): + end_compute = torch.cuda.Event() + + try: + next_task = Task.get(block=False) + if next_task is None: + go_on = False + else: + with torch.cuda.stream(io_stream): + next_task.executor_input.model_input.to(worker.device, + non_blocking=True) + + compute_stream.wait_stream(io_stream) + + with torch.cuda.stream(compute_stream): + next_task.executor_output = worker( + next_task.executor_input) + except queue.Empty: + pass + + end_compute.wait() + if output_to_cpu: + with torch.cuda.stream(io_stream): + current_task.executor_output.to("cpu", non_blocking=True) + io_stream.synchronize() + executor_out.put( + (current_task.scheduler_output, current_task.executor_output)) + + current_task = next_task + next_task = None diff --git a/vllm/wde/prefill_only/layers/__init__.py b/vllm/wde/prefill_only/layers/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/prefill_only/layers/attention/__init__.py b/vllm/wde/prefill_only/layers/attention/__init__.py new file mode 100644 index 0000000000000..8b137891791fe --- /dev/null +++ b/vllm/wde/prefill_only/layers/attention/__init__.py @@ -0,0 +1 @@ + diff --git a/vllm/wde/prefill_only/layers/attention/backends/__init__.py b/vllm/wde/prefill_only/layers/attention/backends/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/prefill_only/layers/attention/backends/abstract.py b/vllm/wde/prefill_only/layers/attention/backends/abstract.py new file mode 100644 index 0000000000000..7080ceb753008 --- /dev/null +++ b/vllm/wde/prefill_only/layers/attention/backends/abstract.py @@ -0,0 +1,59 @@ +from abc import ABC +from dataclasses import dataclass +from typing import Generic, List, Optional, TypeVar + +import torch + +from vllm.utils import is_pin_memory_available +from vllm.wde.core.layers.attention.abstract import (AttentionBackend, + AttentionImpl, + AttentionMetadata, + AttentionMetadataBuilder) + +pin_memory = is_pin_memory_available() + + +class PrefillOnlyAttentionBackend(AttentionBackend, ABC): + pass + + +class PrefillOnlyAttentionImpl(AttentionImpl, ABC): + pass + + +@dataclass +class PrefillOnlyAttentionMetadata(AttentionMetadata): + max_seq_len: int + seq_lens: list[int] + + # (batch_size + 1,). The cumulative sequence lengths of the sequences in + # the batch, used to index into sequence. E.g., if the sequence length is + # [4, 6], it is [0, 4, 10]. 
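The `seq_start_loc` layout described above is an exclusive prefix sum of the per-sequence lengths, which is exactly what the metadata builder below produces by `torch.cumsum`-ing into a preallocated tensor. A standalone illustration with the `[4, 6]` example from the comment:

```python
import torch

seq_lens = [4, 6]
seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.long)
seq_start_loc = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
torch.cumsum(seq_lens_tensor, dim=0, dtype=seq_start_loc.dtype, out=seq_start_loc[1:])
print(seq_start_loc)  # tensor([ 0,  4, 10], dtype=torch.int32)
```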
+ seq_start_loc: Optional[torch.Tensor] + + +T = TypeVar("T", bound=AttentionMetadata) + + +class PrefillOnlyAttentionMetadataBuilder(AttentionMetadataBuilder, + Generic[T]): + + def __init__(self): + pass + + def __call__(self, seq_lens: List[int]): + seq_lens_tensor = torch.tensor(seq_lens, + dtype=torch.long, + pin_memory=pin_memory, + device="cpu") + seq_start_loc = torch.zeros(seq_lens_tensor.shape[0] + 1, + dtype=torch.int32, + device="cpu") + torch.cumsum(seq_lens_tensor, + dim=0, + dtype=seq_start_loc.dtype, + out=seq_start_loc[1:]) + + return PrefillOnlyAttentionMetadata(seq_lens=seq_lens, + max_seq_len=max(seq_lens), + seq_start_loc=seq_start_loc) diff --git a/vllm/wde/prefill_only/layers/attention/backends/flash_attn.py b/vllm/wde/prefill_only/layers/attention/backends/flash_attn.py new file mode 100644 index 0000000000000..6b4305f7afa1c --- /dev/null +++ b/vllm/wde/prefill_only/layers/attention/backends/flash_attn.py @@ -0,0 +1,154 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type + +import torch + +from vllm.wde.core.layers.attention.abstract import AttentionType +from vllm.wde.prefill_only.layers.attention.backends.abstract import ( + PrefillOnlyAttentionBackend, PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata, PrefillOnlyAttentionMetadataBuilder) + + +class PrefillOnlyFlashAttentionBackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "flash_attn" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyFlashAttentionImpl"]: + return PrefillOnlyFlashAttentionImpl + + @staticmethod + def get_metadata_cls() -> Type["PrefillOnlyFlashAttentionMetadata"]: + return PrefillOnlyFlashAttentionMetadata + + @staticmethod + def get_builder_cls() -> Type["PrefillOnlyAttentionMetadataBuilder"]: + return PrefillOnlyFlashAttentionMetadataBuilder + + +@dataclass +class PrefillOnlyFlashAttentionMetadata(PrefillOnlyAttentionMetadata): + pass + + +class PrefillOnlyFlashAttentionMetadataBuilder( + PrefillOnlyAttentionMetadataBuilder[PrefillOnlyFlashAttentionMetadata] +): + pass + + +class PrefillOnlyFlashAttentionImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "FlashAttention does not support block-sparse attention.") + + from vllm_flash_attn import flash_attn_varlen_func + + self.flash_attn_varlen_func = flash_attn_varlen_func + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. 
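For context on the `logits_soft_cap` handling here: logit soft capping squashes attention scores through a scaled tanh before the softmax, and passing 0 is the conventional "disabled" value for flash-attn. A hedged sketch, assuming the usual `cap * tanh(scores / cap)` form (illustrative only, not the kernel implementation):

```python
import torch

def soft_cap(scores: torch.Tensor, cap: float) -> torch.Tensor:
    # cap == 0 is treated as "no capping", matching the backend's default below.
    if cap == 0:
        return scores
    return cap * torch.tanh(scores / cap)

scores = torch.tensor([-100.0, -5.0, 0.0, 5.0, 100.0])
print(soft_cap(scores, 30.0))  # magnitudes are squashed into (-30, 30)
```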
+ logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + raise ValueError( + "Sliding window is not supported in FlashAttention.") + + support_head_sizes = ( + PrefillOnlyFlashAttentionBackend.get_supported_head_sizes()) + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. " + f"Supported head sizes are: {support_head_sizes}.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: PrefillOnlyFlashAttentionMetadata, + kv_cache: Optional[torch.Tensor] = None, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.ENCODER, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyFlashAttentionImpl") + + # NOTE(woosuk): FlashAttention does not support FP8 KV cache. + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in FlashAttention.") + + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + attn_output = self.flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_seq_len, + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scale, + causal=causal, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ) + + # Reshape the output tensor. 
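The varlen kernel returns a single packed tensor for the whole batch (reshaped and returned just below); downstream code recovers per-sequence rows from it with `seq_start_loc`, e.g. `XLMRobertaForSequenceClassification` earlier in this diff takes `sequence_output[seq_start_loc[:-1]]` as its [CLS] features. A small sketch with toy shapes:

```python
import torch

hidden_states = torch.randn(10, 16)        # packed output for two sequences of length 4 and 6
seq_start_loc = torch.tensor([0, 4, 10])

cls_features = hidden_states[seq_start_loc[:-1]]  # first token of each sequence
starts, ends = seq_start_loc[:-1].tolist(), seq_start_loc[1:].tolist()
per_sequence = [hidden_states[s:e] for s, e in zip(starts, ends)]
print(cls_features.shape, [t.shape for t in per_sequence])
# torch.Size([2, 16]) [torch.Size([4, 16]), torch.Size([6, 16])]
```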
+ return attn_output.view(num_tokens, hidden_size) diff --git a/vllm/wde/prefill_only/layers/attention/backends/flashinfer.py b/vllm/wde/prefill_only/layers/attention/backends/flashinfer.py new file mode 100644 index 0000000000000..6dfe56a03fa69 --- /dev/null +++ b/vllm/wde/prefill_only/layers/attention/backends/flashinfer.py @@ -0,0 +1,156 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type + +import torch + +from vllm.wde.core.layers.attention.abstract import AttentionType +from vllm.wde.prefill_only.layers.attention.backends.abstract import ( + PrefillOnlyAttentionBackend, PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata, PrefillOnlyAttentionMetadataBuilder) + + +class PrefillOnlyFlashInferBackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [32, 64, 96, 128, 160, 192, 224, 256] + + @staticmethod + def get_name() -> str: + return "flashinfer" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyFlashInferImpl"]: + return PrefillOnlyFlashInferImpl + + @staticmethod + def get_metadata_cls() -> Type["PrefillOnlyAttentionMetadata"]: + return PrefillOnlyFlashInferMetadata + + @staticmethod + def get_builder_cls() -> Type["PrefillOnlyFlashInferMetadataBuilder"]: + return PrefillOnlyFlashInferMetadataBuilder + + +@dataclass +class PrefillOnlyFlashInferMetadata(PrefillOnlyAttentionMetadata): + pass + + +class PrefillOnlyFlashInferMetadataBuilder( + PrefillOnlyAttentionMetadataBuilder[PrefillOnlyFlashInferMetadata]): + pass + + +class PrefillOnlyFlashInferImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError("PrefillOnlyFlashInferImpl does not " + "support block-sparse attention.") + + from vllm_flash_attn import flash_attn_varlen_func + + self.flash_attn_varlen_func = flash_attn_varlen_func + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = ((sliding_window, sliding_window) + if sliding_window is not None else (-1, -1)) + if logits_soft_cap is None: + # In flash-attn, setting logits_soft_cap as 0 means no soft cap. + logits_soft_cap = 0 + self.logits_soft_cap = logits_soft_cap + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + if sliding_window is not None: + # NOTE(woosuk): flash-attn's sliding window does not work with + # paged KV cache. + raise ValueError( + "Sliding window is not supported in FlashAttention.") + + support_head_sizes = ( + PrefillOnlyFlashInferBackend.get_supported_head_sizes()) + if head_size not in support_head_sizes: + raise ValueError( + f"Head size {head_size} is not supported by FlashAttention. 
" + f"Supported head sizes are: {support_head_sizes}.") + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: PrefillOnlyFlashInferMetadata, + kv_cache: Optional[torch.Tensor] = None, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.ENCODER, + ) -> torch.Tensor: + """Forward pass with FlashAttention. + + Args: + query: shape = [num_tokens, num_heads * head_size] + key: shape = [num_tokens, num_kv_heads * head_size] + value: shape = [num_tokens, num_kv_heads * head_size] + attn_metadata: Metadata for attention. + Returns: + shape = [num_tokens, num_heads * head_size] + """ + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyFlashInferImpl") + + # NOTE(woosuk): FlashAttention does not support FP8 KV cache. + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in FlashAttention.") + + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + # Because encode only models do not involve kv cache + # When using Flashinfer backend in encode only models, + # you are actually using FLASH ATTN backend + attn_output = self.flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=attn_metadata.seq_start_loc, + cu_seqlens_k=attn_metadata.seq_start_loc, + max_seqlen_q=attn_metadata.max_seq_len, + max_seqlen_k=attn_metadata.max_seq_len, + softmax_scale=self.scale, + causal=causal, + window_size=self.sliding_window, + alibi_slopes=self.alibi_slopes, + softcap=self.logits_soft_cap, + ) + + # Reshape the output tensor. 
+ return attn_output.view(num_tokens, hidden_size) diff --git a/vllm/wde/prefill_only/layers/attention/backends/torch_naive.py b/vllm/wde/prefill_only/layers/attention/backends/torch_naive.py new file mode 100644 index 0000000000000..10319d4fecd41 --- /dev/null +++ b/vllm/wde/prefill_only/layers/attention/backends/torch_naive.py @@ -0,0 +1,161 @@ +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type + +import torch + +from vllm.utils import is_pin_memory_available +from vllm.wde.core.layers.attention.abstract import AttentionType +from vllm.wde.prefill_only.layers.attention.backends.abstract import ( + PrefillOnlyAttentionBackend, PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata, PrefillOnlyAttentionMetadataBuilder) + +pin_memory = is_pin_memory_available() + + +class PrefillOnlyTorchNAIVEBackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_name() -> str: + return "torch_naive" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyTorchNaiveBackendImpl"]: + return PrefillOnlyTorchNaiveBackendImpl + + @staticmethod + def get_metadata_cls() -> Type["PrefillOnlyTorchNaiveMetadata"]: + return PrefillOnlyTorchNaiveMetadata + + @staticmethod + def get_builder_cls() -> Type["PrefillOnlyAttentionMetadataBuilder"]: + return PrefillOnlyTorchNaiveMetadataBuilder + + +@dataclass +class PrefillOnlyTorchNaiveMetadata(PrefillOnlyAttentionMetadata): + pass + + +class PrefillOnlyTorchNaiveMetadataBuilder( + PrefillOnlyAttentionMetadataBuilder[PrefillOnlyTorchNaiveMetadata]): + pass + + +class PrefillOnlyTorchNaiveBackendImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "Torch naive does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError("Torch naive does not support logits soft cap.") + + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: PrefillOnlyTorchNaiveMetadata, + kv_cache: Optional[torch.Tensor] = None, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.ENCODER, + ) -> torch.Tensor: + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in TorchNaive.") + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyTorchNaiveBackendImpl") + + num_tokens, hidden_size = query.shape + + # Reshape the query, key, and value tensors. 
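Just below, the naive backend expands K/V heads with `repeat_interleave` so grouped-query / multi-query attention reduces to plain multi-head attention. A toy shape check with hypothetical sizes:

```python
import torch

num_heads, num_kv_heads, head_size, num_tokens = 8, 2, 64, 3
num_queries_per_kv = num_heads // num_kv_heads

key = torch.randn(num_tokens, num_kv_heads, head_size)
key = key.repeat_interleave(num_queries_per_kv, dim=1)
print(key.shape)  # torch.Size([3, 8, 64]) - now matches the query head count
```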
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + start = 0 + output = torch.empty((num_tokens, self.num_heads, self.head_size), + dtype=query.dtype, + device=query.device) + + for seq_len in attn_metadata.seq_lens: + end = start + seq_len + sub_out = scaled_dot_product_attention( + query[None, :, start:end, :], + key[None, :, start:end, :], + value[None, :, start:end, :], + is_causal=causal, + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + output[start:end, :, :] = sub_out + start = end + + # Reshape the output tensor. + return output.view(-1, self.num_heads * self.head_size) + + +def scaled_dot_product_attention(query, + key, + value, + attn_mask=None, + is_causal=False, + scale=None) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool, + device=query.device).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias += attn_mask + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + return attn_weight @ value diff --git a/vllm/wde/prefill_only/layers/attention/backends/torch_sdpa.py b/vllm/wde/prefill_only/layers/attention/backends/torch_sdpa.py new file mode 100644 index 0000000000000..8a8cba9c2aaaa --- /dev/null +++ b/vllm/wde/prefill_only/layers/attention/backends/torch_sdpa.py @@ -0,0 +1,135 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type + +import torch +from torch.nn.functional import scaled_dot_product_attention + +from vllm.utils import is_pin_memory_available +from vllm.wde.core.layers.attention.abstract import AttentionType +from vllm.wde.prefill_only.layers.attention.backends.abstract import ( + PrefillOnlyAttentionBackend, PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata, PrefillOnlyAttentionMetadataBuilder) + +pin_memory = is_pin_memory_available() + + +class PrefillOnlyTorchSDPABackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_name() -> str: + return "torch_sdpa" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyTorchSDPABackendImpl"]: + return PrefillOnlyTorchSDPABackendImpl + + @staticmethod + def get_metadata_cls() -> Type["PrefillOnlyTorchSDPAMetadata"]: + return PrefillOnlyTorchSDPAMetadata + + @staticmethod + def get_builder_cls() -> Type["PrefillOnlyAttentionMetadataBuilder"]: + return PrefillOnlyTorchSDPAMetadataBuilder + + +@dataclass +class PrefillOnlyTorchSDPAMetadata(PrefillOnlyAttentionMetadata): + pass + + +class PrefillOnlyTorchSDPAMetadataBuilder( + PrefillOnlyAttentionMetadataBuilder[PrefillOnlyTorchSDPAMetadata]): + pass + + +class PrefillOnlyTorchSDPABackendImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, 
+ head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "Torch SPDA does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError("Torch SPDA does not support logits soft cap.") + + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + self.need_mask = (self.alibi_slopes is not None + or self.sliding_window is not None) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: PrefillOnlyTorchSDPAMetadata, + kv_cache: Optional[torch.Tensor] = None, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.ENCODER, + ) -> torch.Tensor: + assert k_scale == 1.0 and v_scale == 1.0, ( + "key/v_scale is not supported in TorchSDPA.") + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyTorchSDPABackendImpl") + + num_tokens, hidden_size = query.shape + + # Reshape the query, key, and value tensors. + query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=1) + value = value.repeat_interleave(self.num_queries_per_kv, dim=1) + + query = query.movedim(0, query.dim() - 2) + key = key.movedim(0, key.dim() - 2) + value = value.movedim(0, value.dim() - 2) + + start = 0 + output = torch.empty((num_tokens, self.num_heads, self.head_size), + dtype=query.dtype, + device=query.device) + + for seq_len in attn_metadata.seq_lens: + end = start + seq_len + sub_out = scaled_dot_product_attention( + query[None, :, start:end, :], + key[None, :, start:end, :], + value[None, :, start:end, :], + dropout_p=0.0, + is_causal=causal, + scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) + output[start:end, :, :] = sub_out + start = end + + # Reshape the output tensor. 
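As a sanity check on what this backend delegates to: PyTorch's built-in `scaled_dot_product_attention` computes the same `softmax(QK^T / sqrt(d)) V` as the naive reference backend above. A minimal standalone comparison with toy shapes:

```python
import math
import torch
import torch.nn.functional as F

q, k, v = (torch.randn(1, 8, 5, 64) for _ in range(3))  # [batch, heads, tokens, head_size]

ref = torch.softmax(q @ k.transpose(-2, -1) / math.sqrt(64), dim=-1) @ v
out = F.scaled_dot_product_attention(q, k, v)
print(torch.allclose(ref, out, atol=1e-5))  # True up to numerical tolerance
```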
+ return output.view(-1, self.num_heads * self.head_size) diff --git a/vllm/wde/prefill_only/layers/attention/backends/xformers.py b/vllm/wde/prefill_only/layers/attention/backends/xformers.py new file mode 100644 index 0000000000000..03c9fd5166496 --- /dev/null +++ b/vllm/wde/prefill_only/layers/attention/backends/xformers.py @@ -0,0 +1,137 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Type + +import torch +from xformers import ops as xops +from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask, + BlockDiagonalMask) + +from vllm.logger import init_logger +from vllm.wde.core.layers.attention.abstract import AttentionType +from vllm.wde.prefill_only.layers.attention.backends.abstract import ( + PrefillOnlyAttentionBackend, PrefillOnlyAttentionImpl, + PrefillOnlyAttentionMetadata, PrefillOnlyAttentionMetadataBuilder) + +logger = init_logger(__name__) + + +class PrefillOnlyXFormersBackend(PrefillOnlyAttentionBackend): + + @staticmethod + def get_name() -> str: + return "xformers" + + @staticmethod + def get_impl_cls() -> Type["PrefillOnlyXFormersImpl"]: + return PrefillOnlyXFormersImpl + + @staticmethod + def get_metadata_cls() -> Type["PrefillOnlyAttentionMetadata"]: + return PrefillOnlyXFormersMetadata + + @staticmethod + def get_builder_cls() -> Type["PrefillOnlyXFormersMetadataBuilder"]: + return PrefillOnlyXFormersMetadataBuilder + + +@dataclass +class PrefillOnlyXFormersMetadata(PrefillOnlyAttentionMetadata): + pass + + +class PrefillOnlyXFormersMetadataBuilder( + PrefillOnlyAttentionMetadataBuilder[PrefillOnlyXFormersMetadata]): + pass + + +class PrefillOnlyXFormersImpl(PrefillOnlyAttentionImpl): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: Optional[int] = None, + alibi_slopes: Optional[List[float]] = None, + sliding_window: Optional[int] = None, + kv_cache_dtype: str = "auto", + blocksparse_params: Optional[Dict[str, Any]] = None, + logits_soft_cap: Optional[float] = None, + ) -> None: + if blocksparse_params is not None: + raise ValueError( + "XFormers does not support block-sparse attention.") + if logits_soft_cap is not None: + raise ValueError( + "XFormers does not support attention logits soft capping.") + self.num_heads = num_heads + self.head_size = head_size + self.scale = float(scale) + self.num_kv_heads = num_kv_heads + if alibi_slopes is not None: + alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) + self.alibi_slopes = alibi_slopes + self.sliding_window = sliding_window + assert self.alibi_slopes is None + assert self.num_heads % self.num_kv_heads == 0 + self.num_queries_per_kv = self.num_heads // self.num_kv_heads + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_metadata: PrefillOnlyXFormersMetadata, + kv_cache: Optional[torch.Tensor] = None, + k_scale: float = 1.0, + v_scale: float = 1.0, + attn_type: AttentionType = AttentionType.ENCODER, + ) -> torch.Tensor: + + if attn_type == AttentionType.ENCODER: + causal = False + elif attn_type == AttentionType.DECODER: + causal = True + else: + raise NotImplementedError("Encoder/decoder cross-attention " + "are not implemented for " + "PrefillOnlyXFormersImpl") + original_query = query + + # Reshape the query, key, and value tensors. 
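The xformers path below packs all sequences into one batch and relies on `BlockDiagonalMask.from_seqlens` to keep attention within each sequence. A plain-torch sketch of the mask that bias encodes (illustrative only, not the xformers implementation):

```python
import torch

def block_diagonal_mask(seq_lens):
    total = sum(seq_lens)
    allowed = torch.zeros(total, total, dtype=torch.bool)
    start = 0
    for n in seq_lens:
        allowed[start:start + n, start:start + n] = True  # attend only within own sequence
        start += n
    return allowed

print(block_diagonal_mask([2, 3]).int())
```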
+ query = query.view(-1, self.num_heads, self.head_size) + key = key.view(-1, self.num_kv_heads, self.head_size) + value = value.view(-1, self.num_kv_heads, self.head_size) + + if self.num_kv_heads != self.num_heads: + # GQA/MQA requires the shape [B, M, G, H, K]. + # Note that the output also has the same shape (which is different + # from a spec from the doc). + query = query.view(query.shape[0], self.num_kv_heads, + self.num_queries_per_kv, query.shape[-1]) + key = key[:, :, + None, :].expand(key.shape[0], self.num_kv_heads, + self.num_queries_per_kv, key.shape[-1]) + value = value[:, :, + None, :].expand(value.shape[0], self.num_kv_heads, + self.num_queries_per_kv, + value.shape[-1]) + + if causal: + attn_bias = BlockDiagonalCausalMask.from_seqlens( + attn_metadata.seq_lens) + else: + attn_bias = BlockDiagonalMask.from_seqlens(attn_metadata.seq_lens) + + # Add the batch dimension. + query = query.unsqueeze(0) + key = key.unsqueeze(0) + value = value.unsqueeze(0) + + out = xops.memory_efficient_attention_forward(query, + key, + value, + p=0.0, + attn_bias=attn_bias, + scale=self.scale) + return out.view_as(original_query) diff --git a/vllm/wde/prefill_only/layers/attention/selector.py b/vllm/wde/prefill_only/layers/attention/selector.py new file mode 100644 index 0000000000000..c6234d1d82b3f --- /dev/null +++ b/vllm/wde/prefill_only/layers/attention/selector.py @@ -0,0 +1,126 @@ +import enum +from typing import Optional + +import torch + +import vllm.envs as envs +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.wde.core.layers.attention.abstract import AttentionType +from vllm.wde.core.llm_engine import LLMEngine + +logger = init_logger(__name__) + + +class _Backend(enum.Enum): + FLASH_ATTN = enum.auto() + XFORMERS = enum.auto() + ROCM_FLASH = enum.auto() + TORCH_SDPA = enum.auto() + OPENVINO = enum.auto() + FLASHINFER = enum.auto() + PALLAS = enum.auto() + IPEX = enum.auto() + TORCH_NAIVE = enum.auto() + + @staticmethod + def backend_name_to_enum(backend_name: str) -> "_Backend": + assert backend_name is not None + + backend_members = _Backend.__members__ + if backend_name not in backend_members: + raise ValueError( + f"Invalid attention backend '{backend_name}'. 
" + f"Available backends: {', '.join(backend_members)} " + "(case-sensitive).") + + return _Backend[backend_name] + + +class AttnBackend: + + @classmethod + def from_engine(cls, engine: LLMEngine): + model_config = engine.engine_config.model_config + num_heads = model_config.get_num_attention_heads() + head_size = model_config.get_head_size() + num_kv_heads = model_config.get_num_kv_heads() + sliding_window = model_config.get_sliding_window() + dtype = model_config.dtype + + backend = cls.which_attn_to_use(num_heads, head_size, num_kv_heads, + sliding_window, dtype) + + backend_cls = cls.get_backend_cls(backend) + + attn_type = AttentionType.attn_type_name_to_enum( + engine.workflow.attn_type) + + return backend_cls(attn_type) + + @staticmethod + def get_backend_cls(backend): + if backend == _Backend.FLASH_ATTN: + logger.info("Using FLASH ATTN backend.") + from vllm.wde.prefill_only.layers.attention.backends.flash_attn import ( # noqa: E501 + PrefillOnlyFlashAttentionBackend) + return PrefillOnlyFlashAttentionBackend + if backend == _Backend.XFORMERS: + logger.info("Using XFormers backend.") + from vllm.wde.prefill_only.layers.attention.backends.xformers import ( # noqa: E501 + PrefillOnlyXFormersBackend) + return PrefillOnlyXFormersBackend + elif backend == _Backend.TORCH_SDPA: + logger.info("Using Torch SDPA backend.") + from vllm.wde.prefill_only.layers.attention.backends.torch_sdpa import ( # noqa: E501 + PrefillOnlyTorchSDPABackend) + return PrefillOnlyTorchSDPABackend + elif backend == _Backend.FLASHINFER: + logger.info("Using Flashinfer backend.") + logger.info("When using Flashinfer backend in encode only models, " + "you are actually using FLASH ATTN backend") + from vllm.wde.prefill_only.layers.attention.backends.flashinfer import ( # noqa: E501 + PrefillOnlyFlashInferBackend) + return PrefillOnlyFlashInferBackend + elif backend == _Backend.TORCH_NAIVE: + logger.info("Using Torch naive backend.") + from vllm.wde.prefill_only.layers.attention.backends.torch_naive import ( # noqa: E501 + PrefillOnlyTorchNAIVEBackend) + return PrefillOnlyTorchNAIVEBackend + else: + raise ValueError("Invalid attention backend.") + + @classmethod + def which_attn_to_use(cls, num_heads: int, head_size: int, + num_kv_heads: int, sliding_window: Optional[int], + dtype: torch.dtype): + # Default case. + selected_backend = _Backend.FLASH_ATTN + + # get_env_variable_attn_backend + # Check the environment variable and override if specified + backend_by_env_var: Optional[str] = envs.VLLM_ATTENTION_BACKEND + if backend_by_env_var is not None: + selected_backend = _Backend.backend_name_to_enum( + backend_by_env_var) + + # FlashAttn in NVIDIA GPUs. + if selected_backend == _Backend.FLASH_ATTN: + if current_platform.get_device_capability()[0] < 8: + # Volta and Turing NVIDIA GPUs. + logger.info( + "Cannot use FlashAttention-2 backend for Volta and Turing " + "GPUs.") + selected_backend = _Backend.XFORMERS + elif dtype not in (torch.float16, torch.bfloat16): + logger.info( + "Cannot use FlashAttention-2 backend for dtype other than " + "torch.float16 or torch.bfloat16.") + selected_backend = _Backend.XFORMERS + elif sliding_window is not None: + logger.info( + "Cannot use FlashAttention-2 backend due to sliding window." 
+ ) + selected_backend = _Backend.XFORMERS + + return selected_backend diff --git a/vllm/wde/prefill_only/processor/__init__.py b/vllm/wde/prefill_only/processor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/prefill_only/processor/input_processor.py b/vllm/wde/prefill_only/processor/input_processor.py new file mode 100644 index 0000000000000..3fcabb77b1f1b --- /dev/null +++ b/vllm/wde/prefill_only/processor/input_processor.py @@ -0,0 +1,66 @@ +import time +from typing import Optional + +from vllm.wde.core.inputs.tokenizer import Tokenizer +from vllm.wde.core.llm_engine import LLMEngine +from vllm.wde.core.processor.input_processor import (InputProcessor, + RequestProcessor) +from vllm.wde.core.schema.engine_io import (Params, PromptInput, TextPrompt, + TokensPrompt) +from vllm.wde.prefill_only.schema.engine_io import ( + PrefillOnlyInput, PrefillOnlyRequest, PrefillOnlySchedulableRequest) + + +class PrefillOnlyModelInputProcessor(InputProcessor): + + @classmethod + def from_engine(cls, engine: LLMEngine): + return cls() + + def __call__(self, + request_id: str, + inputs: Optional[PromptInput] = None, + params: Optional[Params] = None, + arrival_time: Optional[float] = None) -> PrefillOnlyRequest: + if not arrival_time: + arrival_time = time.time() + request = PrefillOnlyRequest(request_id=str(request_id), + inputs=inputs, + arrival_time=arrival_time) + return request + + +class PrefillOnlyModelRequestProcessor(RequestProcessor): + + def __init__(self, tokenizer: Tokenizer): + self.tokenizer = tokenizer + + @classmethod + def from_engine(cls, engine: LLMEngine): + return cls(engine.tokenizer) + + def __call__(self, + request: PrefillOnlyRequest) -> PrefillOnlySchedulableRequest: + inputs = request.inputs + + if isinstance(inputs, str): + inputs = {"prompt": inputs} + elif isinstance(input, TextPrompt): + inputs = {"prompt": inputs.prompt} + elif isinstance(input, TokensPrompt): + inputs = {"prompt_token_ids", inputs.prompt_token_ids} + + if "prompt_token_ids" not in inputs: + tokenizer = self.tokenizer + + prompt_token_ids = tokenizer.encode(inputs["prompt"]) + else: + prompt_token_ids = inputs["prompt_token_ids"] + + schedulable_request = PrefillOnlySchedulableRequest( + request_id=request.request_id, + inputs=PrefillOnlyInput(prompt_token_ids=prompt_token_ids, + prompt=inputs.get("prompt")), + arrival_time=request.arrival_time) + + return schedulable_request diff --git a/vllm/wde/prefill_only/processor/model_input_builder.py b/vllm/wde/prefill_only/processor/model_input_builder.py new file mode 100644 index 0000000000000..190c8dd8fb413 --- /dev/null +++ b/vllm/wde/prefill_only/processor/model_input_builder.py @@ -0,0 +1,52 @@ +import torch + +from vllm.utils import is_pin_memory_available +from vllm.wde.core.llm_engine import LLMEngine +from vllm.wde.core.processor.model_input_builder import ModelInputBuilder +from vllm.wde.core.schema.execute_io import ExecuteInput +from vllm.wde.prefill_only.layers.attention.backends.abstract import ( + PrefillOnlyAttentionMetadataBuilder) +from vllm.wde.prefill_only.schema.engine_io import PrefillOnlySchedulerOutput +from vllm.wde.prefill_only.schema.execute_io import ModelInputForGPU + +pin_memory = is_pin_memory_available() + + +class PrefillOnlyModelInputBuilder(ModelInputBuilder): + + def __init__( + self, + attention_metadata_builder: PrefillOnlyAttentionMetadataBuilder): + self.attention_metadata_builder = attention_metadata_builder + + @classmethod + def from_engine(cls, engine: LLMEngine): + return 
cls(engine.attn_backend.get_builder_cls()()) + + def __call__(self, + scheduler_output: PrefillOnlySchedulerOutput) -> ExecuteInput: + input_tokens = [] + input_positions = [] + seq_lens = [] + for request in scheduler_output.requests: + prompt_token_ids = request.inputs.prompt_token_ids + n_tokens = len(prompt_token_ids) + input_tokens.extend(prompt_token_ids) + input_positions.extend(list(range(0, n_tokens))) + seq_lens.append(n_tokens) + + input_ids = torch.tensor(input_tokens, + dtype=torch.long, + pin_memory=pin_memory, + device="cpu") + positions = torch.tensor(input_positions, + dtype=torch.long, + pin_memory=pin_memory, + device="cpu") + attn_metadata = self.attention_metadata_builder(seq_lens) + + model_input = ModelInputForGPU(input_ids=input_ids, + positions=positions, + attn_metadata=attn_metadata) + + return ExecuteInput(worker_input=None, model_input=model_input) diff --git a/vllm/wde/prefill_only/processor/output_processor.py b/vllm/wde/prefill_only/processor/output_processor.py new file mode 100644 index 0000000000000..a0cefe51e1b07 --- /dev/null +++ b/vllm/wde/prefill_only/processor/output_processor.py @@ -0,0 +1,52 @@ +from typing import List, Sequence, Union + +import torch + +from vllm.wde.core.llm_engine import LLMEngine +from vllm.wde.core.processor.output_processor import OutputProcessor +from vllm.wde.core.schema.engine_io import ValidationError +from vllm.wde.prefill_only.schema.engine_io import (PrefillOnlyRequestOutput, + PrefillOnlySchedulerOutput) + + +class PrefillOnlyModelOutputProcessor(OutputProcessor): + + def __init__(self): + pass + + @classmethod + def from_engine(cls, engine: LLMEngine): + return cls() + + def __call__( + self, scheduler_output: PrefillOnlySchedulerOutput, + execute_output: Union[torch.Tensor, Sequence[torch.Tensor]] + ) -> List[PrefillOnlyRequestOutput]: + if isinstance(execute_output, torch.Tensor): + request_outputs = [] + offset = 0 + for request in scheduler_output.requests: + prompt_token_ids = request.inputs.prompt_token_ids + n_tokens = len(prompt_token_ids) + request_outputs.append( + PrefillOnlyRequestOutput( + request_id=request.request_id, + prompt_token_ids=prompt_token_ids, + finished=True, + outputs=execute_output[offset:offset + n_tokens])) + offset += n_tokens + return request_outputs + elif isinstance(execute_output, (list, tuple)): + sequence_output, pooled_output = execute_output + request_outputs = [] + for request, outputs in zip(scheduler_output.requests, + pooled_output): + prompt_token_ids = request.inputs.prompt_token_ids + request_outputs.append( + PrefillOnlyRequestOutput(request_id=request.request_id, + prompt_token_ids=prompt_token_ids, + finished=True, + outputs=outputs)) + return request_outputs + else: + raise ValidationError("Output format not supported") diff --git a/vllm/wde/prefill_only/runner/__init__.py b/vllm/wde/prefill_only/runner/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/prefill_only/runner/model_runner.py b/vllm/wde/prefill_only/runner/model_runner.py new file mode 100644 index 0000000000000..e2e550d807d29 --- /dev/null +++ b/vllm/wde/prefill_only/runner/model_runner.py @@ -0,0 +1,60 @@ +import torch +import torch.nn as nn + +from vllm.logger import init_logger +from vllm.utils import DeviceMemoryProfiler, is_pin_memory_available +from vllm.wde.core.config import DeviceConfig, LoadConfig, ModelConfig +from vllm.wde.core.layers.attention import AttentionBackend +from vllm.wde.prefill_only.config import PrefillOnlySchedulerConfig +from 
vllm.wde.prefill_only.schema.execute_io import ModelInputForGPU + +logger = init_logger(__name__) + + +class ModelRunner: + + def __init__( + self, + model_config: ModelConfig, + scheduler_config: PrefillOnlySchedulerConfig, + device_config: DeviceConfig, + load_config: LoadConfig, + attn_backend: AttentionBackend, + ): + self.model_config = model_config + self.scheduler_config = scheduler_config + self.device_config = device_config + self.load_config = load_config + self.attn_backend = attn_backend + self.device = self.device_config.device + self.pin_memory = is_pin_memory_available() + + # Lazy initialization + self.model: nn.Module # Set after load_model + + def load_model(self) -> None: + from vllm.wde.core.loader.loader import (get_model_loader, + initialize_model) + + logger.info("Starting to load model %s...", self.model_config.model) + with DeviceMemoryProfiler() as m: + loader = get_model_loader(self.load_config) + self.model = initialize_model(model_config=self.model_config, + load_config=self.load_config, + device_config=self.device_config, + attn_backend=self.attn_backend) + + loader.load_model(self.model, + model_config=self.model_config, + device_config=self.device_config) + + self.model_memory_usage = m.consumed_memory + logger.info("Loading model weights took %.4f GB", + self.model_memory_usage / float(2**30)) + + @torch.inference_mode() + def execute_model( + self, + model_input: ModelInputForGPU, + ): + return self.model(**model_input.to_dict()) diff --git a/vllm/wde/prefill_only/scheduler.py b/vllm/wde/prefill_only/scheduler.py new file mode 100644 index 0000000000000..ea6795b5b46b2 --- /dev/null +++ b/vllm/wde/prefill_only/scheduler.py @@ -0,0 +1,89 @@ +from dataclasses import dataclass, field +from typing import Set + +from vllm.logger import init_logger +from vllm.wde.core.scheduler import Scheduler +from vllm.wde.core.schema.engine_io import SchedulableRequest +from vllm.wde.prefill_only.config import PrefillOnlySchedulerConfig +from vllm.wde.prefill_only.processor.input_processor import ( + PrefillOnlyModelRequestProcessor) +from vllm.wde.prefill_only.schema.engine_io import PrefillOnlySchedulerOutput + +logger = init_logger(__name__) + + +@dataclass +class SchedulingBudget: + token_budget: int + max_num_requests: int + _curr_requests: Set[str] = field(default_factory=set) + _num_batched_tokens: int = 0 + + def can_schedule(self, *, num_new_tokens: int, num_new_request: int = 1): + assert num_new_tokens != 0 + assert num_new_request != 0 + a = self.num_batched_tokens + num_new_tokens <= self.token_budget + b = self.num_curr_request + num_new_request <= self.max_num_requests + return a and b + + def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int): + if req_id in self._curr_requests: + return + + self._curr_requests.add(req_id) + self._num_batched_tokens += num_batched_tokens + + @property + def num_batched_tokens(self): + return self._num_batched_tokens + + @property + def num_curr_request(self): + return len(self._curr_requests) + + +class PrefillOnlyScheduler(Scheduler): + support_scheduling = ["sync_scheduling", "async_scheduling"] + + def __init__( + self, + scheduler_config: PrefillOnlySchedulerConfig, + request_processor: PrefillOnlyModelRequestProcessor, + ) -> None: + super().__init__(scheduler_config, request_processor) + + @classmethod + def from_engine(cls, engine): + return cls(engine.engine_config.scheduler_config, + engine.request_processor) + + def schedule(self): + budget = SchedulingBudget( + 
token_budget=self.scheduler_config.max_num_batched_tokens, + max_num_requests=self.scheduler_config.max_num_requests, + ) + + waiting_queue = self.waiting + + scheduler_outputs = [] + while waiting_queue: + request = waiting_queue[0] + if request.request_id in self.aborted_requests: + self.aborted_requests.remove(request.request_id) + waiting_queue.popleft() + continue + + if not isinstance(request, SchedulableRequest): + request = self.request_processor(request) + waiting_queue[0] = request + + num_new_tokens = request.num_new_tokens + + if not budget.can_schedule(num_new_tokens=num_new_tokens): + break + + budget.add_num_batched_tokens(request.request_id, num_new_tokens) + waiting_queue.popleft() + scheduler_outputs.append(request) + + return PrefillOnlySchedulerOutput(requests=scheduler_outputs) diff --git a/vllm/wde/prefill_only/schema/__init__.py b/vllm/wde/prefill_only/schema/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/prefill_only/schema/engine_io.py b/vllm/wde/prefill_only/schema/engine_io.py new file mode 100644 index 0000000000000..a5b9625ac2175 --- /dev/null +++ b/vllm/wde/prefill_only/schema/engine_io.py @@ -0,0 +1,51 @@ +from dataclasses import dataclass +from typing import Iterable, List + +import torch + +from vllm.wde.core.schema.engine_io import (PromptInput, Request, + RequestOutput, SchedulableRequest, + SchedulerOutput, TextOnlyInputs) + + +@dataclass +class PrefillOnlyInput(TextOnlyInputs): + pass + + +@dataclass +class PrefillOnlyRequest(Request): + inputs: PromptInput + + +@dataclass +class PrefillOnlySchedulableRequest(SchedulableRequest): + inputs: TextOnlyInputs + + @property + def num_new_tokens(self): + return len(self.inputs.prompt_token_ids) + + +@dataclass +class PrefillOnlySchedulerOutput(SchedulerOutput): + requests: Iterable[PrefillOnlyRequest] + + def is_empty(self) -> bool: + return not self.requests + + +class PrefillOnlyRequestOutput(RequestOutput): + + def __init__(self, request_id: str, outputs: torch.Tensor, + prompt_token_ids: List[int], finished: bool): + self.request_id = request_id + self.prompt_token_ids = prompt_token_ids + self.finished = finished + self.outputs = outputs + + def __repr__(self): + return (f"PrefillOnlyRequestOutput(request_id='{self.request_id}', " + f"outputs={repr(self.outputs)}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"finished={self.finished})") diff --git a/vllm/wde/prefill_only/schema/execute_io.py b/vllm/wde/prefill_only/schema/execute_io.py new file mode 100644 index 0000000000000..26a419428a85b --- /dev/null +++ b/vllm/wde/prefill_only/schema/execute_io.py @@ -0,0 +1,26 @@ +from dataclasses import dataclass + +import torch + +from vllm.wde.core.layers.attention import AttentionMetadata +from vllm.wde.core.schema.execute_io import ExecuteInput, ModelInput + + +@dataclass +class ModelInputForGPU(ModelInput): + input_ids: torch.Tensor + positions: torch.Tensor + attn_metadata: AttentionMetadata + + def to(self, target_device, non_blocking=False): + for k in self.__dict__: + self.__dict__[k] = self.__dict__[k].to(device=target_device, + non_blocking=non_blocking) + + def to_dict(self): + return self.__dict__ + + +class PrefillOnlyExecuteInput(ExecuteInput): + worker_input = None + model_input: ModelInputForGPU diff --git a/vllm/wde/prefill_only/worker/__init__.py b/vllm/wde/prefill_only/worker/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/prefill_only/worker/gpu_worker.py b/vllm/wde/prefill_only/worker/gpu_worker.py new 
file mode 100644 index 0000000000000..5c6014f90902f --- /dev/null +++ b/vllm/wde/prefill_only/worker/gpu_worker.py @@ -0,0 +1,137 @@ +import os +from typing import List, Optional + +import torch + +from vllm.model_executor.utils import set_random_seed +from vllm.platforms import current_platform +from vllm.wde.core.config import (DeviceConfig, EngineConfig, LoadConfig, + ModelConfig) +from vllm.wde.core.layers.attention import AttentionBackend +from vllm.wde.core.worker import WorkerBase +from vllm.wde.prefill_only.config import PrefillOnlySchedulerConfig +from vllm.wde.prefill_only.runner.model_runner import ModelRunner +from vllm.wde.prefill_only.schema.execute_io import PrefillOnlyExecuteInput + + +class Worker(WorkerBase): + + def __init__( + self, + engine_config: EngineConfig, + attn_backend: AttentionBackend, + ) -> None: + self.model_config: ModelConfig = engine_config.model_config + self.scheduler_config: PrefillOnlySchedulerConfig = ( + engine_config.scheduler_config) + self.device_config: DeviceConfig = engine_config.device_config + self.load_config: LoadConfig = engine_config.load_config + self.device = self.device_config.device + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + self.model_runner = ModelRunner(self.model_config, + self.scheduler_config, + self.device_config, self.load_config, + attn_backend) + + def init_device(self) -> None: + if self.device_config.device.type == "cuda": + # torch.distributed.all_reduce does not free the input tensor until + # the synchronization point. This causes the memory usage to grow + # as the number of all_reduce calls increases. This env var disables + # this behavior. + # Related issue: + # https://discuss.pytorch.org/t/cuda-allocation-lifetime-for-inputs-to-distributed-all-reduce/191573 + os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1" + + # This env var set by Ray causes exceptions with graph building. + os.environ.pop("NCCL_ASYNC_ERROR_HANDLING", None) + self.device = torch.device("cuda:0") + torch.cuda.set_device(self.device) + + _check_if_gpu_supports_dtype(self.model_config.dtype) + torch.cuda.empty_cache() + self.init_gpu_memory = torch.cuda.mem_get_info()[0] + else: + raise RuntimeError( + f"Not support device type: {self.device_config.device}") + + self.dirty_fix_distributed_environment() + + # Set random seed. + set_random_seed(self.model_config.seed) + + def dirty_fix_distributed_environment(self): + # This dirty_fix can make ParallelLinear etc. work properly. + # Why should tp and model layers be coupled together? + + import vllm.distributed.parallel_state + + fake_parallel_group = FakeGroupCoordinator() + vllm.distributed.parallel_state._TP = fake_parallel_group + vllm.distributed.parallel_state._PP = fake_parallel_group + + @torch.inference_mode + def load_model(self): + self.model_runner.load_model() + + @torch.inference_mode + def __call__(self, execute_input: PrefillOnlyExecuteInput): + output = self.model_runner.execute_model(execute_input.model_input) + return output + + +def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): + # Check if the GPU supports the dtype. + if torch_dtype == torch.bfloat16: + compute_capability = current_platform.get_device_capability() + if compute_capability[0] < 8: + gpu_name = torch.cuda.get_device_name() + raise ValueError( + "Bfloat16 is only supported on GPUs with compute capability " + f"of at least 8.0. 
Your {gpu_name} GPU has compute capability " + f"{compute_capability[0]}.{compute_capability[1]}. " + "You can use float16 instead by explicitly setting the" + "`dtype` flag in CLI, for example: --dtype=half.") + + +class FakeGroupCoordinator: + rank: int = 0 + ranks: List[int] = [0] + world_size: int = 1 + local_rank: int = 0 + rank_in_group: int = 0 + + def destroy(self): + pass + + @property + def first_rank(self): + return self.ranks[0] + + @property + def last_rank(self): + return self.ranks[-1] + + @property + def is_first_rank(self): + return self.rank == self.first_rank + + @property + def is_last_rank(self): + return self.rank == self.last_rank + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + return input_ + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + return input_ + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + return input_ diff --git a/vllm/wde/prefill_only/workflow.py b/vllm/wde/prefill_only/workflow.py new file mode 100644 index 0000000000000..c500a3ecdd60d --- /dev/null +++ b/vllm/wde/prefill_only/workflow.py @@ -0,0 +1,46 @@ +from vllm.wde.core.workflow import Workflow + + +class PrefillOnlyWorkflow(Workflow): + + InputProcessor: str = ("vllm.wde.prefill_only.processor." + "input_processor:PrefillOnlyModelInputProcessor") + RequestProcessor: str = ( + "vllm.wde.prefill_only.processor." + "input_processor:PrefillOnlyModelRequestProcessor") + OutputProcessor: str = ("vllm.wde.prefill_only.processor." + "output_processor:PrefillOnlyModelOutputProcessor") + ModelInputBuilder: str = ( + "vllm.wde.prefill_only.processor." + "model_input_builder:PrefillOnlyModelInputBuilder") + Worker: str = "vllm.wde.prefill_only.worker.gpu_worker:Worker" + Executor: str = "vllm.wde.prefill_only.executor.gpu_executor" + Scheduler: str = "vllm.wde.prefill_only.scheduler:PrefillOnlyScheduler" + AttnBackend: str = ("vllm.wde.prefill_only.layers." + "attention.selector:AttnBackend") + + @classmethod + def from_engine(cls, engine): + workflow = cls() + + if engine.engine_config.parallel_config is None: + if engine.engine_config.scheduler_config.scheduling in ["sync"]: + workflow.Executor += ":GPUExecutor" + elif engine.engine_config.scheduler_config.scheduling in [ + "async", "double_buffer" + ]: + workflow.Executor += ":GPUAsyncExecutor" + else: + assert engine.engine_config.parallel_config.data_parallel_size > 0 + assert engine.engine_config.scheduler_config.scheduling in [ + "async", "double_buffer" + ] + + engine.engine_config.scheduler_config.max_num_on_the_fly *= ( + engine.engine_config.parallel_config.data_parallel_size) + + workflow.Executor = ( + "vllm.wde.prefill_only.executor.gpu_data_parallelism_executor:" + "GPUDataParallelismExecutor") + + return workflow diff --git a/vllm/wde/reranker/__init__.py b/vllm/wde/reranker/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/reranker/modelzoo/__init__.py b/vllm/wde/reranker/modelzoo/__init__.py new file mode 100644 index 0000000000000..761ad1144f190 --- /dev/null +++ b/vllm/wde/reranker/modelzoo/__init__.py @@ -0,0 +1,8 @@ +TASK = "reranker" +WORKFLOW = "vllm.wde.reranker.workflow:RerankerWorkflow" + +# Architecture -> (task, module, class, workflow). 
+RERANKER_MODELS = { + "XLMRobertaForSequenceClassification": + (TASK, "bge_reranker_v2_m3", "BGERerankerV2M3", WORKFLOW), +} diff --git a/vllm/wde/reranker/modelzoo/bge_reranker_v2_m3.py b/vllm/wde/reranker/modelzoo/bge_reranker_v2_m3.py new file mode 100644 index 0000000000000..48d8d8fbe8df7 --- /dev/null +++ b/vllm/wde/reranker/modelzoo/bge_reranker_v2_m3.py @@ -0,0 +1,11 @@ +# Adapted from +# https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/reranker/modeling.py +# https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/flag_reranker.py +# FlagEmbedding is licensed under the MIT License. + +from vllm.wde.encode_only.modelzoo.xlm_roberta import ( + XLMRobertaForSequenceClassification) + + +class BGERerankerV2M3(XLMRobertaForSequenceClassification): + pass diff --git a/vllm/wde/reranker/processor/__init__.py b/vllm/wde/reranker/processor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/reranker/processor/input_processor.py b/vllm/wde/reranker/processor/input_processor.py new file mode 100644 index 0000000000000..b1ac3d48a6a0c --- /dev/null +++ b/vllm/wde/reranker/processor/input_processor.py @@ -0,0 +1,58 @@ +import time +from typing import Optional, Sequence + +from vllm.wde.core.inputs.tokenizer import Tokenizer +from vllm.wde.core.llm_engine import LLMEngine +from vllm.wde.core.processor.input_processor import (InputProcessor, + RequestProcessor) +from vllm.wde.core.schema.engine_io import Params, ValidationError +from vllm.wde.prefill_only.schema.engine_io import ( + PrefillOnlyInput, PrefillOnlySchedulableRequest) +from vllm.wde.reranker.schema.engine_io import (Pairs, RerankerInputs, + RerankerRequest) + + +class RerankerInputProcessor(InputProcessor): + + @classmethod + def from_engine(cls, engine: LLMEngine): + return cls() + + def __call__(self, + request_id: str, + inputs: Optional[RerankerInputs] = None, + params: Optional[Params] = None, + arrival_time: Optional[float] = None) -> RerankerRequest: + if not arrival_time: + arrival_time = time.time() + + if isinstance(inputs, Sequence): + if len(inputs) != 2: + raise ValidationError("Reranker model input must be pairs.") + inputs = Pairs(query=inputs[0], passage=inputs[1]) + + request = RerankerRequest(request_id=str(request_id), + inputs=inputs, + arrival_time=arrival_time) + return request + + +class RerankerRequestProcessor(RequestProcessor): + + def __init__(self, tokenizer: Tokenizer): + self.tokenizer = tokenizer + + @classmethod + def from_engine(cls, engine: LLMEngine): + return cls(engine.tokenizer) + + def __call__(self, + request: RerankerRequest) -> PrefillOnlySchedulableRequest: + text_pair = (request.inputs.query, request.inputs.passage) + prompt_token_ids = self.tokenizer.encode(text_pair) + schedulable_request = PrefillOnlySchedulableRequest( + request_id=request.request_id, + inputs=PrefillOnlyInput(prompt_token_ids=prompt_token_ids, + prompt=None), + arrival_time=request.arrival_time) + return schedulable_request diff --git a/vllm/wde/reranker/processor/output_processor.py b/vllm/wde/reranker/processor/output_processor.py new file mode 100644 index 0000000000000..9a17b0e859f4e --- /dev/null +++ b/vllm/wde/reranker/processor/output_processor.py @@ -0,0 +1,31 @@ +from typing import List + +import torch + +from vllm.wde.core.llm_engine import LLMEngine +from vllm.wde.core.processor.output_processor import OutputProcessor +from vllm.wde.prefill_only.schema.engine_io import PrefillOnlySchedulerOutput +from vllm.wde.reranker.schema.engine_io 
import RerankerRequestOutput + + +class RerankerOutputProcessor(OutputProcessor): + + def __init__(self): + pass + + @classmethod + def from_engine(cls, engine: LLMEngine): + return cls() + + def __call__(self, scheduler_output: PrefillOnlySchedulerOutput, + execute_output: torch.Tensor) -> List[RerankerRequestOutput]: + execute_output = execute_output.view(-1, ).cpu().numpy().tolist() + request_outputs = [] + for i, request in enumerate(scheduler_output.requests): + prompt_token_ids = request.inputs.prompt_token_ids + request_outputs.append( + RerankerRequestOutput(request_id=request.request_id, + prompt_token_ids=prompt_token_ids, + finished=True, + score=float(execute_output[i]))) + return request_outputs diff --git a/vllm/wde/reranker/schema/__init__.py b/vllm/wde/reranker/schema/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/reranker/schema/engine_io.py b/vllm/wde/reranker/schema/engine_io.py new file mode 100644 index 0000000000000..bfc2a15ebedba --- /dev/null +++ b/vllm/wde/reranker/schema/engine_io.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass +from typing import List, Sequence, Union + +from vllm.wde.core.schema.engine_io import Inputs, Request, RequestOutput + + +@dataclass +class Pairs(Inputs): + query: str + passage: str + + +RerankerInputs = Union[Sequence, Pairs] + + +@dataclass +class RerankerRequest(Request): + inputs: Pairs + + +class RerankerRequestOutput(RequestOutput): + + def __init__(self, request_id: str, score: float, + prompt_token_ids: List[int], finished: bool): + self.request_id = request_id + self.prompt_token_ids = prompt_token_ids + self.finished = finished + self.score = score + + def __repr__(self): + return (f"RerankerRequestOutput(request_id='{self.request_id}', " + f"score={repr(self.score)}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"finished={self.finished})") diff --git a/vllm/wde/reranker/workflow.py b/vllm/wde/reranker/workflow.py new file mode 100644 index 0000000000000..b6919ddc96408 --- /dev/null +++ b/vllm/wde/reranker/workflow.py @@ -0,0 +1,10 @@ +from vllm.wde.encode_only.workflow import EncodeOnlyWorkflow + + +class RerankerWorkflow(EncodeOnlyWorkflow): + InputProcessor: str = ("vllm.wde.reranker.processor." + "input_processor:RerankerInputProcessor") + RequestProcessor: str = ("vllm.wde.reranker.processor." + "input_processor:RerankerRequestProcessor") + OutputProcessor: str = ("vllm.wde.reranker.processor." + "output_processor:RerankerOutputProcessor") diff --git a/vllm/wde/retriever/__init__.py b/vllm/wde/retriever/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/retriever/arg_utils.py b/vllm/wde/retriever/arg_utils.py new file mode 100644 index 0000000000000..4ad34c7fb97be --- /dev/null +++ b/vllm/wde/retriever/arg_utils.py @@ -0,0 +1,8 @@ +from vllm.wde.decode_only.arg_utils import DecodeOnlyEngineArgs + + +class RetrieverDecodeOnlyEngineArgs(DecodeOnlyEngineArgs): + + def __post_init__(self): + super().__post_init__() + self.output_last_hidden_states = True diff --git a/vllm/wde/retriever/modelzoo/__init__.py b/vllm/wde/retriever/modelzoo/__init__.py new file mode 100644 index 0000000000000..3e72eba478d92 --- /dev/null +++ b/vllm/wde/retriever/modelzoo/__init__.py @@ -0,0 +1,25 @@ +TASK = "retriever" +RETRIEVER_ENCODER_ONLY_WORKFLOW = ("vllm.wde.retriever.workflow:" + "RetrieverEncodeOnlyWorkflow") + +# Architecture -> (task, module, class, workflow). 
+RETRIEVER_ENCODER_ONLY_MODELS = { + "XLMRobertaModel": + (TASK, "bge_m3", "BGEM3Model", RETRIEVER_ENCODER_ONLY_WORKFLOW), + "BertModel": + (TASK, "bert_retriever", "BertRetriever", RETRIEVER_ENCODER_ONLY_WORKFLOW), +} + +RETRIEVER_DECODER_ONLY_WORKFLOW = ("vllm.wde.retriever.workflow:" + "RetrieverDecodeOnlyWorkflow") + +# Architecture -> (task, module, class, workflow). +RETRIEVER_DECODER_ONLY_MODELS = { + "MistralModel": (TASK, "llama_embedding", "LlamaEmbeddingModel", + RETRIEVER_DECODER_ONLY_WORKFLOW) +} + +RETRIEVER_MODELS = { + **RETRIEVER_ENCODER_ONLY_MODELS, + **RETRIEVER_DECODER_ONLY_MODELS +} diff --git a/vllm/wde/retriever/modelzoo/bert_retriever.py b/vllm/wde/retriever/modelzoo/bert_retriever.py new file mode 100644 index 0000000000000..ad7422f9e047e --- /dev/null +++ b/vllm/wde/retriever/modelzoo/bert_retriever.py @@ -0,0 +1,68 @@ +# Adapted from +# https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/flag_models.py +# FlagEmbedding is licensed under the MIT License. +# BertRetriever also supports Snowflake Arctic Embed (Family) +# Arctic is licensed under the Apache-2. + +from typing import Optional + +import torch +from torch import nn + +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.wde.core.layers.attention import AttentionBackend, AttentionMetadata +from vllm.wde.encode_only.modelzoo.bert import (BertConfig, BertModel, + LoadWeightsMixin) + + +class BertRetriever(nn.Module, LoadWeightsMixin): + # bge v1.5 family + # Snowflake Arctic Embed (Family) + + prefix = "bert." + _ignore_weights_keys = [ + "bert.embeddings.position_ids", 'bert.pooler.dense.weight' + ] + + def __init__(self, + config: BertConfig, + attn_backend: AttentionBackend, + quant_config: Optional[QuantizationConfig] = None, + sentence_pooling_method="cls", + normalized=True, + *args, + **kwargs): + super().__init__() + self.config = config + self.quant_config = quant_config + self.sentence_pooling_method = sentence_pooling_method + assert self.sentence_pooling_method == 'cls' + self.normalized = normalized + + self.bert = BertModel(config, + attn_backend, + quant_config=quant_config, + add_pooling_layer=False) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + + sequence_output, pooled_output = self.bert( + input_ids, + positions, + attn_metadata, + ) + + seq_start_loc = attn_metadata.seq_start_loc + + dense_vecs = sequence_output[seq_start_loc[:-1]] + + if self.normalized: + dense_vecs = torch.nn.functional.normalize(dense_vecs, dim=-1) + + return dense_vecs diff --git a/vllm/wde/retriever/modelzoo/bge_m3.py b/vllm/wde/retriever/modelzoo/bge_m3.py new file mode 100644 index 0000000000000..9aec85792d749 --- /dev/null +++ b/vllm/wde/retriever/modelzoo/bge_m3.py @@ -0,0 +1,61 @@ +# Adapted from +# https://github.com/FlagOpen/FlagEmbedding/blob/master/FlagEmbedding/BGE_M3/modeling.py +# FlagEmbedding is licensed under the MIT License. + +from typing import Optional + +import torch +from torch import nn + +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.wde.core.layers.attention import AttentionBackend, AttentionMetadata +from vllm.wde.encode_only.modelzoo.xlm_roberta import (LoadWeightsMixin, + XLMRobertaConfig, + XLMRobertaModel) + + +class BGEM3Model(nn.Module, LoadWeightsMixin): + _ignore_weights_keys = [ + "roberta.pooler.dense.weight", "roberta.pooler.dense.bias" + ] + + prefix = "roberta." 
+
+    def __init__(self,
+                 config: XLMRobertaConfig,
+                 attn_backend: AttentionBackend,
+                 quant_config: Optional[QuantizationConfig] = None,
+                 sentence_pooling_method="cls",
+                 normalized=True,
+                 *args,
+                 **kwargs):
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.sentence_pooling_method = sentence_pooling_method
+        assert self.sentence_pooling_method == 'cls'
+        self.normalized = normalized
+        self.roberta = XLMRobertaModel(config, attn_backend, quant_config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+
+        sequence_output = self.roberta(
+            input_ids,
+            positions,
+            attn_metadata,
+        )
+
+        seq_start_loc = attn_metadata.seq_start_loc
+
+        dense_vecs = sequence_output[seq_start_loc[:-1]]
+
+        if self.normalized:
+            dense_vecs = torch.nn.functional.normalize(dense_vecs, dim=-1)
+
+        return dense_vecs
diff --git a/vllm/wde/retriever/modelzoo/gte_qwen/__init__.py b/vllm/wde/retriever/modelzoo/gte_qwen/__init__.py
new file mode 100644
index 0000000000000..e69de29bb2d1d
diff --git a/vllm/wde/retriever/modelzoo/gte_qwen/arg_utils.py b/vllm/wde/retriever/modelzoo/gte_qwen/arg_utils.py
new file mode 100644
index 0000000000000..ff6495d78b3a0
--- /dev/null
+++ b/vllm/wde/retriever/modelzoo/gte_qwen/arg_utils.py
@@ -0,0 +1,28 @@
+from dataclasses import dataclass
+
+from vllm.logger import init_logger
+from vllm.wde.decode_only.arg_utils import (DecodeOnlyEngineArgs,
+                                            DecodeOnlyEngineConfig,
+                                            filter_unexpected_fields)
+
+logger = init_logger(__name__)
+
+
+@filter_unexpected_fields
+@dataclass
+class Qwen2EngineArgs(DecodeOnlyEngineArgs):
+
+    def create_engine_config(self) -> DecodeOnlyEngineConfig:
+        # gte-Qwen2 and Qwen2 share the same architecture name,
+        # Qwen2ForCausalLM, and the gte-Qwen2 family may contain several
+        # different architectures.
+        # gte-Qwen2-1.5B-instruct does not enable bidirectional attention
+        # (possibly an upstream bug), while gte-Qwen2-7B-instruct does.
+        if "gte-Qwen2-1.5B-instruct" in self.model:
+            self.output_last_hidden_states = True
+        elif "gte-Qwen2-7B-instruct" in self.model:
+            self.output_last_hidden_states = True
+            self.enable_bidirectional = True
+
+        config = super().create_engine_config()
+        return config
diff --git a/vllm/wde/retriever/modelzoo/gte_qwen/workflow.py b/vllm/wde/retriever/modelzoo/gte_qwen/workflow.py
new file mode 100644
index 0000000000000..eac132e902d44
--- /dev/null
+++ b/vllm/wde/retriever/modelzoo/gte_qwen/workflow.py
@@ -0,0 +1,6 @@
+from vllm.wde.decode_only.workflow import DecodeOnlyWorkflow
+
+
+class Qwen2Workflow(DecodeOnlyWorkflow):
+    EngineArgs: str = ("vllm.wde.retriever.modelzoo."
+                       "gte_qwen.arg_utils:Qwen2EngineArgs")
diff --git a/vllm/wde/retriever/modelzoo/llama_embedding.py b/vllm/wde/retriever/modelzoo/llama_embedding.py
new file mode 100644
index 0000000000000..a6a5abf7256d6
--- /dev/null
+++ b/vllm/wde/retriever/modelzoo/llama_embedding.py
@@ -0,0 +1,76 @@
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.wde.core.layers.attention import AttentionMetadata
+from vllm.wde.decode_only.modelzoo.llama import LlamaModel
+
+
+class LlamaEmbeddingModel(nn.Module):
+    """A model that uses Llama with additional embedding functionalities.
+
+    This class encapsulates the LlamaModel and provides an interface for
+    embedding operations and customized pooling functions.
+ + Attributes: + model: An instance of LlamaModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__( + self, + **kwargs, + ) -> None: + super().__init__() + self.model = LlamaModel(**kwargs) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + attn_metadata: AttentionMetadata, + kv_caches: Optional[List[torch.Tensor]] = None, + inputs_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.model.forward(input_ids, positions, attn_metadata, + kv_caches, inputs_embeds) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + params_dict = dict(self.model.named_parameters()) + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + for (param_name, weight_name, shard_id) in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) diff --git a/vllm/wde/retriever/processor/__init__.py b/vllm/wde/retriever/processor/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/retriever/processor/output_processor.py b/vllm/wde/retriever/processor/output_processor.py new file mode 100644 index 0000000000000..2317a325178df --- /dev/null +++ b/vllm/wde/retriever/processor/output_processor.py @@ -0,0 +1,23 @@ +from typing import List + +import torch + +from vllm.wde.prefill_only.processor.output_processor import ( + PrefillOnlyModelOutputProcessor) +from vllm.wde.prefill_only.schema.engine_io import PrefillOnlySchedulerOutput +from vllm.wde.retriever.schema.engine_io import EmbeddingRequestOutput + + +class RetrieverModelOutputProcessor(PrefillOnlyModelOutputProcessor): + + def __call__(self, scheduler_output: PrefillOnlySchedulerOutput, + execute_output: torch.Tensor) -> List[EmbeddingRequestOutput]: + request_outputs = [] + for request, outputs in zip(scheduler_output.requests, execute_output): + prompt_token_ids = request.inputs.prompt_token_ids + request_outputs.append( + EmbeddingRequestOutput(request_id=request.request_id, + prompt_token_ids=prompt_token_ids, + finished=True, + outputs=outputs)) + return request_outputs diff --git a/vllm/wde/retriever/schema/__init__.py b/vllm/wde/retriever/schema/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/wde/retriever/schema/engine_io.py b/vllm/wde/retriever/schema/engine_io.py new file mode 100644 index 0000000000000..ef616eee515d9 --- /dev/null +++ b/vllm/wde/retriever/schema/engine_io.py @@ -0,0 +1,30 @@ +from typing import List + +import torch + +from 
vllm.wde.core.schema.engine_io import RequestOutput
+
+
+class EmbeddingRequestOutput(RequestOutput):
+
+    def __init__(self, request_id: str, outputs: torch.Tensor,
+                 prompt_token_ids: List[int], finished: bool):
+        self.request_id = request_id
+        self.prompt_token_ids = prompt_token_ids
+        self.finished = finished
+        self.outputs = outputs
+
+    def __repr__(self):
+        """
+        Returns a string representation of an EmbeddingRequestOutput instance.
+
+        The representation includes the request_id, outputs, prompt_token_ids
+        and finished, providing a quick overview of the request's results.
+
+        Returns:
+            str: A string representation of the EmbeddingRequestOutput instance.
+        """
+        return (f"EmbeddingRequestOutput(request_id='{self.request_id}', "
+                f"outputs={repr(self.outputs)}, "
+                f"prompt_token_ids={self.prompt_token_ids}, "
+                f"finished={self.finished})")
diff --git a/vllm/wde/retriever/workflow.py b/vllm/wde/retriever/workflow.py
new file mode 100644
index 0000000000000..0dabc98961fd8
--- /dev/null
+++ b/vllm/wde/retriever/workflow.py
@@ -0,0 +1,12 @@
+from vllm.wde.decode_only.workflow import DecodeOnlyWorkflow
+from vllm.wde.encode_only.workflow import EncodeOnlyWorkflow
+
+
+class RetrieverEncodeOnlyWorkflow(EncodeOnlyWorkflow):
+    OutputProcessor: str = ("vllm.wde.retriever.processor."
+                            "output_processor:RetrieverModelOutputProcessor")
+
+
+class RetrieverDecodeOnlyWorkflow(DecodeOnlyWorkflow):
+    EngineArgs: str = ("vllm.wde.retriever.arg_utils:"
+                       "RetrieverDecodeOnlyEngineArgs")
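
A short usage sketch of the SchedulingBudget dataclass added in vllm/wde/prefill_only/scheduler.py, showing how can_schedule caps a batch by both the token budget and the request count. This is an illustration only, not part of the patch; it assumes this branch is installed so the new module is importable.

from vllm.wde.prefill_only.scheduler import SchedulingBudget

budget = SchedulingBudget(token_budget=10, max_num_requests=2)

assert budget.can_schedule(num_new_tokens=5)       # 5 <= 10 tokens, 1 <= 2 requests
budget.add_num_batched_tokens("req-0", 5)

assert not budget.can_schedule(num_new_tokens=6)   # 5 + 6 > 10: token budget exhausted
assert budget.can_schedule(num_new_tokens=4)       # 5 + 4 <= 10 and 2 <= 2
budget.add_num_batched_tokens("req-1", 4)

assert not budget.can_schedule(num_new_tokens=1)   # 2 + 1 > 2: request cap reached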
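
A minimal, self-contained sketch of the "cls" pooling that BGEM3Model and BertRetriever apply to the packed sequence_output. The toy sequence lengths are made up, and it assumes attn_metadata.seq_start_loc holds the zero-prefixed cumulative sum of the per-request sequence lengths, which is the convention the slicing sequence_output[seq_start_loc[:-1]] relies on.

import torch
import torch.nn.functional as F

# Three requests of lengths 2, 3 and 1 packed into one [6, hidden] tensor,
# the same layout the prefill-only model runner feeds the model.
hidden_size = 4
seq_lens = torch.tensor([2, 3, 1])
sequence_output = torch.randn(int(seq_lens.sum()), hidden_size)

# Assumed convention: seq_start_loc = [0, 2, 5, 6], i.e. a zero-prefixed
# cumulative sum of seq_lens.
seq_start_loc = torch.zeros(len(seq_lens) + 1, dtype=torch.long)
seq_start_loc[1:] = torch.cumsum(seq_lens, dim=0)

# "cls" pooling: take each request's first-token hidden state ...
dense_vecs = sequence_output[seq_start_loc[:-1]]
# ... and L2-normalize it, as done when normalized=True.
dense_vecs = F.normalize(dense_vecs, dim=-1)

print(dense_vecs.shape)  # torch.Size([3, 4])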