diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py
index 86e952a903f36..79cd96edce59f 100644
--- a/vllm/attention/backends/torch_sdpa.py
+++ b/vllm/attention/backends/torch_sdpa.py
@@ -297,26 +297,33 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                 dtype=torch.int32,
                 device="cpu",
             )
-            query_lens_tensor = torch.tensor(prefill_query_lens,
-                                             dtype=torch.int32,
-                                             device="cpu")
-            kv_lens_tensor = torch.tensor(prefill_seq_lens,
-                                          dtype=torch.int32,
-                                          device="cpu")
-            query_start_loc = torch.zeros(input_data.num_prefills + 1,
-                                          dtype=torch.int32,
-                                          device="cpu")
-            kv_start_loc = torch.zeros(input_data.num_prefills + 1,
-                                       dtype=torch.int32,
-                                       device="cpu")
-            torch.cumsum(query_lens_tensor,
-                         dim=0,
-                         dtype=torch.int32,
-                         out=query_start_loc[1:])
-            torch.cumsum(kv_lens_tensor,
-                         dim=0,
-                         dtype=torch.int32,
-                         out=kv_start_loc[1:])
+            query_start_loc: torch.Tensor
+            kv_start_loc: torch.Tensor
+            # Reuse the start locations computed by the v1 CPU model runner
+            # when they are provided; otherwise rebuild them from the
+            # per-request lengths.
+            if (input_data.query_start_loc is not None
+                    and input_data.seq_start_loc is not None):
+                query_start_loc = input_data.query_start_loc[
+                    :input_data.num_prefills + 1]
+                kv_start_loc = input_data.seq_start_loc
+            else:
+                query_lens_tensor = torch.tensor(prefill_query_lens,
+                                                 dtype=torch.int32,
+                                                 device="cpu")
+                kv_lens_tensor = torch.tensor(prefill_seq_lens,
+                                              dtype=torch.int32,
+                                              device="cpu")
+
+                query_start_loc = torch.zeros(input_data.num_prefills + 1,
+                                              dtype=torch.int32,
+                                              device="cpu")
+                kv_start_loc = torch.zeros(input_data.num_prefills + 1,
+                                           dtype=torch.int32,
+                                           device="cpu")
+                # E.g., query_lens [2, 5, 3] -> query_start_loc [0, 2, 7, 10].
+                torch.cumsum(query_lens_tensor,
+                             dim=0,
+                             dtype=torch.int32,
+                             out=query_start_loc[1:])
+                torch.cumsum(kv_lens_tensor,
+                             dim=0,
+                             dtype=torch.int32,
+                             out=kv_start_loc[1:])
             max_query_len = max(prefill_query_lens)
             max_kv_len = max(prefill_seq_lens)
         else:
diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py
index 251a103e60f06..4a72a47e752c8 100644
--- a/vllm/v1/attention/backends/flash_attn.py
+++ b/vllm/v1/attention/backends/flash_attn.py
@@ -6,7 +6,7 @@
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata,
                                               AttentionType)
-from vllm.vllm_flash_attn import flash_attn_varlen_func
+# NOTE: flash_attn_varlen_func is imported lazily inside forward() so that
+# this module stays importable on builds without vllm_flash_attn (e.g. CPU).


 class FlashAttentionBackend(AttentionBackend):
@@ -166,6 +166,7 @@ def forward(
         )

         # Compute attention and update output up to `num_actual_tokens`.
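+        # Lazy import (see the note at the top of this module):
+        # vllm_flash_attn is unavailable on CPU-only builds.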
+        from vllm.vllm_flash_attn import flash_attn_varlen_func
         flash_attn_varlen_func(
             q=query[:num_actual_tokens],
             k=key_cache,
diff --git a/vllm/v1/core/scheduler.py b/vllm/v1/core/scheduler.py
index a3e85c20cc664..a985f13433ecf 100644
--- a/vllm/v1/core/scheduler.py
+++ b/vllm/v1/core/scheduler.py
@@ -203,6 +203,9 @@ def schedule(self) -> "SchedulerOutput":
                     num_computed_tokens -= 1
                     num_new_tokens = 1
                     computed_blocks.pop()
+                # If the request does not fit in the remaining token budget,
+                # stop scheduling instead of scheduling it partially.
+                if num_new_tokens > token_budget:
+                    break
                 num_new_tokens = min(num_new_tokens, token_budget)
                 assert num_new_tokens > 0
diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py
index 26fd650aee4b7..5a7800d10eab6 100644
--- a/vllm/v1/engine/async_llm.py
+++ b/vllm/v1/engine/async_llm.py
@@ -10,6 +10,7 @@
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.outputs import PoolingRequestOutput, RequestOutput
+from vllm.platforms import current_platform
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
@@ -21,6 +22,7 @@
 from vllm.v1.engine.detokenizer import Detokenizer
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.executor.cpu_executor import CPUExecutor

 logger = init_logger(__name__)

@@ -127,6 +129,8 @@ def shutdown(self):

     @classmethod
     def _get_executor_cls(cls, vllm_config: VllmConfig):
+        if current_platform.is_cpu():
+            return CPUExecutor
         distributed_executor_backend = (
             vllm_config.parallel_config.distributed_executor_backend)
         if distributed_executor_backend == "mp":
diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py
index 1b3a9f12d009e..0aa161ebcb659 100644
--- a/vllm/v1/engine/llm_engine.py
+++ b/vllm/v1/engine/llm_engine.py
@@ -11,6 +11,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
 from vllm.outputs import RequestOutput
+from vllm.platforms import current_platform
 from vllm.pooling_params import PoolingParams
 from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sampling_params import SamplingParams
@@ -21,6 +22,7 @@
 from vllm.v1.engine.detokenizer import Detokenizer
 from vllm.v1.engine.processor import Processor
 from vllm.v1.executor.abstract import Executor
+from vllm.v1.executor.cpu_executor import CPUExecutor

 logger = init_logger(__name__)

@@ -104,6 +106,8 @@ def from_engine_args(

     @classmethod
     def _get_executor_cls(cls, vllm_config: VllmConfig):
+        if current_platform.is_cpu():
+            return CPUExecutor
         distributed_executor_backend = (
             vllm_config.parallel_config.distributed_executor_backend)
         if distributed_executor_backend == "mp":
diff --git a/vllm/v1/executor/cpu_executor.py b/vllm/v1/executor/cpu_executor.py
new file mode 100644
index 0000000000000..2ecfeca9d4483
--- /dev/null
+++ b/vllm/v1/executor/cpu_executor.py
@@ -0,0 +1,337 @@
+import os
+from functools import partial
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
+
+from vllm.config import VllmConfig
+from vllm.executor.executor_base import ExecutorAsyncBase
+from vllm.executor.multiproc_worker_utils import (ProcessWorkerWrapper,
+                                                  ResultHandler,
+                                                  WorkerMonitor)
+from vllm.logger import init_logger
+from vllm.model_executor.layers.sampler import SamplerOutput
+from vllm.sequence import ExecuteModelRequest
+from vllm.utils import (enable_trace_function_call_for_thread,
+                        get_distributed_init_method, get_open_port,
+                        make_async, resolve_obj_by_qualname,
+                        update_environment_variables)
+from vllm.v1.executor.abstract import Executor
+from vllm.v1.worker.cpu_worker import CPUWorkerV1
+
+logger = init_logger(__name__)
+
+
+class CPUExecutor(Executor):
+
+    uses_ray: bool = False
+
+    def __init__(self, vllm_config: VllmConfig) -> None:
+        self.vllm_config = vllm_config
+        self.model_config = vllm_config.model_config
+        self.cache_config = vllm_config.cache_config
+        self.lora_config = vllm_config.lora_config
+        self.load_config = vllm_config.load_config
+        self.parallel_config = vllm_config.parallel_config
+        self.scheduler_config = vllm_config.scheduler_config
+        self.device_config = vllm_config.device_config
+        self.speculative_config = vllm_config.speculative_config
+        self.prompt_adapter_config = vllm_config.prompt_adapter_config
+        self.observability_config = vllm_config.observability_config
+        assert self.device_config.device_type == "cpu"
+        # Reminder: please update docs/source/serving/compatibility_matrix.rst
+        # if the feature combination becomes valid.
+        assert self.lora_config is None, "cpu backend doesn't support LoRA"
+
+        #
+        # Environment variables for the CPU executor
+        #
+        # Disable torch async compiling, which does not work with daemonic
+        # processes.
+        os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
+
+        # Intel OpenMP settings
+        ld_preload_str = os.getenv("LD_PRELOAD", "")
+        if "libiomp5.so" in ld_preload_str:
+            # The time (in milliseconds) that a thread should wait after
+            # completing the execution of a parallel region, before sleeping.
+            os.environ['KMP_BLOCKTIME'] = "1"
+            # Prevents the CPU from entering a low-performance state.
+            os.environ['KMP_TPAUSE'] = "0"
+            # Provides fine-grained parallelism.
+            os.environ['KMP_FORKJOIN_BARRIER_PATTERN'] = "dist,dist"
+            os.environ['KMP_PLAIN_BARRIER_PATTERN'] = "dist,dist"
+            os.environ['KMP_REDUCTION_BARRIER_PATTERN'] = "dist,dist"
+
+        # Hint IPEX to use its shared-memory-based AllReduce.
+        os.environ["LOCAL_WORLD_SIZE"] = str(
+            self.parallel_config.tensor_parallel_size)
+
+        # The multiprocessing-based executor does not support multi-node
+        # settings. Since it only works for a single node, we can use the
+        # loopback address 127.0.0.1 for communication.
+ ip = "127.0.0.1" + port = get_open_port() + self.distributed_init_method = get_distributed_init_method(ip, port) + + is_async = isinstance(self, CPUExecutorAsync) + + world_size = self.parallel_config.tensor_parallel_size + result_handler = ResultHandler() + self.parallel_worker_tasks: Optional[Union[Any, Awaitable[Any]]] = None + self.workers = [] + + if is_async: + self.workers = [ + ProcessWorkerWrapper( + result_handler, + partial( + self._create_worker, + rank=rank, + local_rank=rank, + )) for rank in range(0, world_size) + ] + self.driver_worker = self.workers[0] + self.workers = self.workers[1:] + self.driver_method_invoker = _async_driver_method_invoker + else: + self.driver_worker = self._create_worker() + self.driver_method_invoker = _driver_method_invoker + + if world_size != 1: + self.workers = [ + ProcessWorkerWrapper( + result_handler, + partial( + self._create_worker, + rank=rank, + local_rank=rank, + )) for rank in range(1, world_size) + ] + + self.worker_monitor = None + if world_size != 1 or is_async: + if is_async: + async_worker_list = self.workers + [self.driver_worker] + else: + async_worker_list = self.workers + self.worker_monitor = WorkerMonitor(async_worker_list, + result_handler) + result_handler.start() + self.worker_monitor.start() + + self._run_workers("initialize") + self._run_workers("load_model") + + def _create_worker( + self, + local_rank: int = 0, + rank: int = 0, + ): + + wrapper = WorkerWrapperBaseV1(vllm_config=self.vllm_config) + + assert self.distributed_init_method is not None + + kwargs = dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=self.distributed_init_method, + # kv_cache_dtype=self.cache_config.cache_dtype, + is_driver_worker=rank == 0, + ) + wrapper.init_worker(**kwargs) + + return wrapper.worker + + def _run_workers( + self, + method: str, + *args, + async_run_remote_workers_only: bool = False, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers. + + Args: + async_run_remote_workers_only: If True the method will be run only + in the remote workers, not the driver worker. It will also be + run asynchronously and return a list of futures rather than + blocking on the results. + """ + + if max_concurrent_workers: + raise NotImplementedError( + "max_concurrent_workers is not supported yet.") + + # Start the workers first. + worker_outputs = [ + worker.execute_method(method, *args, **kwargs) + for worker in self.workers + ] + + if async_run_remote_workers_only: + # Just return futures + return worker_outputs + + driver_worker_output = self.driver_method_invoker( + self.driver_worker, method, *args, **kwargs) + + # Get the results of the workers. + return [driver_worker_output + ] + [output.get() for output in worker_outputs] + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks by invoking the + underlying worker. + """ + return self.driver_method_invoker(self.driver_worker, + "determine_num_available_blocks") + + def initialize(self, num_gpu_blocks: int, + num_cpu_blocks: int = 0) -> None: + """Initialize the KV cache by invoking the underlying worker. + """ + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. + # NOTE: `cpu block` for CPU backend is located on CPU memory but is + # referred as `gpu block`. 
Because we want to reuse the existing block
+        # management procedure.
+        logger.info("# CPU blocks: %d", num_gpu_blocks)
+
+        self._run_workers("initialize_cache",
+                          num_gpu_blocks=num_gpu_blocks,
+                          num_cpu_blocks=num_cpu_blocks)
+
+    def execute_model(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        output = self.driver_method_invoker(self.driver_worker,
+                                            "execute_model",
+                                            execute_model_req)
+        return output
+
+    def stop_remote_worker_execution_loop(self) -> None:
+        if self.parallel_worker_tasks is None:
+            return
+        # Passing None causes the driver to stop the model execution loop
+        # running in each of the remote workers.
+        self.driver_method_invoker(self.driver_worker, "execute_model", None)
+        parallel_worker_tasks = self.parallel_worker_tasks
+        self.parallel_worker_tasks = None
+        # Ensure that workers exit the model loop cleanly
+        # (this will raise otherwise).
+        self._wait_for_tasks_completion(parallel_worker_tasks)
+
+    def check_health(self) -> None:
+        """Raises an error if the engine is unhealthy."""
+        if (self.worker_monitor is not None
+                and not self.worker_monitor.is_alive()):
+            raise RuntimeError("Worker processes are not running")
+
+    def shutdown(self):
+        if (worker_monitor := getattr(self, "worker_monitor",
+                                      None)) is not None:
+            worker_monitor.close()
+
+    def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
+        """Wait for futures returned from _run_workers() with
+        async_run_remote_workers_only to complete."""
+        for result in parallel_worker_tasks:
+            result.get()
+
+    def profile(self, is_start=True):
+        raise NotImplementedError
+
+    def collective_rpc(self,
+                       method: str,
+                       timeout: Optional[float] = None,
+                       args: Tuple = (),
+                       kwargs: Optional[Dict] = None) -> List[Any]:
+        raise NotImplementedError
+
+
+class CPUExecutorAsync(CPUExecutor, ExecutorAsyncBase):
+
+    async def execute_model_async(
+            self,
+            execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
+        output = await make_async(self.execute_model)(
+            execute_model_req=execute_model_req)
+        return output
+
+    async def check_health_async(self) -> None:
+        self.check_health()
+
+
+def _driver_method_invoker(driver, method: str, *args, **kwargs):
+    return getattr(driver, method)(*args, **kwargs)
+
+
+def _async_driver_method_invoker(driver, method: str, *args, **kwargs):
+    return driver.execute_method(method, *args, **kwargs).get()
+
+
+class WorkerWrapperBaseV1:
+    """
+    The whole point of this class is to lazily initialize the worker.
+    We first instantiate the WorkerWrapper, which remembers the worker module
+    and class name. The environment can then be set up via
+    `update_environment_variables` before the real initialization happens in
+    `init_worker`.
+    """
+
+    def __init__(
+        self,
+        vllm_config: VllmConfig,
+    ) -> None:
+        self.vllm_config = vllm_config
+        trust_remote_code = vllm_config.model_config.trust_remote_code
+        self.worker: Optional[CPUWorkerV1] = None
+        if trust_remote_code:
+            # note: lazy import to avoid importing torch before initializing
+            from vllm.utils import init_cached_hf_modules
+            init_cached_hf_modules()
+
+    @staticmethod
+    def update_environment_variables(envs: Dict[str, str]) -> None:
+        key = 'CUDA_VISIBLE_DEVICES'
+        if key in envs and key in os.environ:
+            # Overwriting CUDA_VISIBLE_DEVICES is desired behavior;
+            # suppress the warning in `update_environment_variables`.
+            del os.environ[key]
+        update_environment_variables(envs)
+
+    def init_worker(self, *args, **kwargs):
+        """
+        Here we inject some common logic before initializing the worker.
+        Arguments are passed to the worker class constructor.
+        """
+        enable_trace_function_call_for_thread(self.vllm_config)
+
+        # See https://github.com/NVIDIA/nccl/issues/1234; harmless on CPU,
+        # kept for parity with the GPU worker wrapper.
+        os.environ['NCCL_CUMEM_ENABLE'] = '0'
+
+        from vllm.plugins import load_general_plugins
+        load_general_plugins()
+
+        worker_class = resolve_obj_by_qualname(
+            "vllm.v1.worker.cpu_worker.CPUWorkerV1")
+        self.worker = worker_class(*args, **kwargs)
+        assert self.worker is not None
+
+    def execute_method(self, method, *args, **kwargs):
+        try:
+            target = self if self.worker is None else self.worker
+            executor = getattr(target, method)
+            return executor(*args, **kwargs)
+        except Exception as e:
+            # If the driver worker also executes methods, exceptions in the
+            # other workers may cause a deadlock in RPC frameworks such as
+            # Ray; see https://github.com/vllm-project/vllm/issues/3455.
+            # Log the error and inform the user how to resolve it.
+            msg = (f"Error executing method {method}. "
+                   "This might cause deadlock in distributed execution.")
+            logger.exception(msg)
+            raise e
+
+    def __getattr__(self, attr):
+        return getattr(self.worker, attr)
diff --git a/vllm/v1/worker/cpu_model_runner.py b/vllm/v1/worker/cpu_model_runner.py
new file mode 100644
index 0000000000000..cf72a56cead62
--- /dev/null
+++ b/vllm/v1/worker/cpu_model_runner.py
@@ -0,0 +1,262 @@
+from typing import TYPE_CHECKING, List
+
+import numpy as np
+import torch
+
+from vllm.attention import get_attn_backend
+from vllm.forward_context import set_forward_context
+from vllm.model_executor.model_loader import get_model
+from vllm.v1.outputs import ModelRunnerOutput
+from vllm.v1.worker.gpu_model_runner import GPUModelRunner
+from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder
+
+if TYPE_CHECKING:
+    from vllm.v1.core.scheduler import SchedulerOutput
+
+
+class CPUModelRunner(GPUModelRunner):
+    """GPUModelRunner variant that executes the model on the CPU device."""
+
+    def __init__(self, vllm_config):
+        super().__init__(vllm_config, vllm_config.device_config.device)
+        self.use_cuda_graph = False
+        num_attn_heads = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        needs_attn_backend = (num_attn_heads != 0
+                              or self.model_config.is_attention_free)
+        self.attn_backend = get_attn_backend(
+            self.model_config.get_head_size(),
+            self.model_config.dtype,
+            self.kv_cache_dtype,
+            self.block_size,
+            self.model_config.is_attention_free,
+        ) if needs_attn_backend else None
+        # Reuse the v0 CPU builder's input-data container (use_mrope=False)
+        # and the metadata builder of the CPU attention backend.
+        self.input_data = ModelInputForCPUBuilder.ModelInputData(False)
+        self.chunked_prefill = True
+        self.att_metadata_builder = self.attn_backend.get_builder_cls()(self)
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        kv_caches: List[torch.Tensor],
+        scheduler_output: "SchedulerOutput",
+    ) -> ModelRunnerOutput:
+        self._update_states(scheduler_output)
+
+        # Run the encoder.
+        self._execute_encoder(scheduler_output)
+        encoder_outputs = self._gather_encoder_outputs(scheduler_output)
+
+        # Prepare the decoder inputs.
+        input_ids, attn_metadata, logits_indices = self._prepare_inputs(
+            scheduler_output)
+        num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens
+        # Eager mode only on CPU, so no padding for CUDA graphs is needed.
+        num_input_tokens = num_scheduled_tokens
+
+        # Get the input embeddings.
+        if encoder_outputs:
+            inputs_embeds = self.model.get_input_embeddings(
+                input_ids, encoder_outputs)
+        else:
+            inputs_embeds = self.model.get_input_embeddings(input_ids)
+        # NOTE(woosuk): To unify token ids and soft tokens (vision embeddings),
+        # always use embeddings (rather than token ids) as input to the model.
+        # TODO(woosuk): Avoid the copy. Optimize.
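+        # The persistent inputs_embeds buffer is kept to share this code path
+        # with the GPU runner; there are no CUDA graphs to satisfy on CPU.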
+        self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
+
+        # Run the decoder (eager mode only; CUDA graphs do not apply on CPU).
+        with set_forward_context(attn_metadata, self.vllm_config):
+            hidden_states = self.model(
+                input_ids=None,
+                positions=self.positions[:num_input_tokens],
+                kv_caches=kv_caches,
+                attn_metadata=None,  # retrieved from the forward context
+                inputs_embeds=self.inputs_embeds[:num_input_tokens],
+            )
+        hidden_states = hidden_states[:num_scheduled_tokens]
+        hidden_states = hidden_states[logits_indices]
+        logits = self.model.compute_logits(hidden_states, None)
+
+        # Sample the next token and get logprobs if needed.
+        sampling_metadata = self._prepare_sampling(scheduler_output)
+        sampler_output = self.model.sample(
+            logits=logits,
+            sampling_metadata=sampling_metadata,
+        )
+
+        # NOTE: In the GPU runner this is where CPU-GPU synchronization
+        # happens; on CPU the sampled token ids are already host-resident.
+        sampled_token_ids = sampler_output.sampled_token_ids
+        # TODO(woosuk): The following loop can be slow since it iterates over
+        # the requests one by one. Optimize.
+        num_reqs = self.input_batch.num_reqs
+        for i, req_id in enumerate(self.input_batch.req_ids[:num_reqs]):
+            req_state = self.requests[req_id]
+            seq_len = (req_state.num_computed_tokens +
+                       scheduler_output.num_scheduled_tokens[req_id])
+            assert seq_len <= req_state.num_tokens
+            if seq_len == req_state.num_tokens:
+                # Append the sampled token to the output token ids.
+                token_id = sampled_token_ids[i]
+                self.input_batch.token_ids_cpu[i, seq_len] = token_id
+                req_state.output_token_ids.append(token_id)
+            else:
+                # Ignore the sampled token from the partial request.
+                # Rewind the generator state as if the token was not sampled.
+                generator = self.input_batch.generators.get(i)
+                if generator is not None:
+                    # WARNING: get_offset()/set_offset() are CUDA-generator
+                    # internals; CPU generators may not support this rewind.
+                    generator.set_offset(generator.get_offset() - 4)
+
+        # The logprob tensors are already CPU tensors here, so they can be
+        # passed through without a device copy.
+        logprob_token_ids = sampler_output.logprob_token_ids
+        logprobs = sampler_output.logprobs
+        model_runner_output = ModelRunnerOutput(
+            req_ids=self.input_batch.req_ids[:num_reqs],
+            req_id_to_index=self.input_batch.req_id_to_index,
+            sampled_token_ids=sampled_token_ids,
+            logprob_token_ids_cpu=logprob_token_ids,
+            logprobs_cpu=logprobs,
+        )
+        return model_runner_output
+
+    def load_model(self) -> None:
+        self.model = get_model(vllm_config=self.vllm_config)
+
+    def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
+        total_num_scheduled_tokens = (
+            scheduler_output.total_num_scheduled_tokens)
+        assert total_num_scheduled_tokens > 0
+        num_reqs = self.input_batch.num_reqs
+        assert num_reqs > 0
+        num_prefills = 0
+        num_prefill_tokens = 0
+        num_decode_tokens = 0
+        #
+        # OPTIMIZATION: Start copying the block table first.
+        # This way, we can overlap the copy with the following CPU operations.
+        # self.input_batch.block_table[:num_reqs].copy_(
+        #     self.input_batch.block_table_cpu_tensor[:num_reqs],
+        #     non_blocking=True)
+
+        # Get the number of scheduled tokens for each request.
+        # TODO: The Python loop can be slow. Optimize.
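+        # A request counts as a decode iff exactly one token was scheduled
+        # for it; a chunked prefill therefore still counts as a prefill.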
+ num_scheduled_tokens = [] + max_num_scheduled_tokens = 0 + for req_id in self.input_batch.req_ids[:num_reqs]: + num_tokens = scheduler_output.num_scheduled_tokens[req_id] + if num_tokens == 1: + num_decode_tokens += 1 + else: + num_prefills += 1 + num_prefill_tokens += num_tokens + num_scheduled_tokens.append(num_tokens) + max_num_scheduled_tokens = max(max_num_scheduled_tokens, + num_tokens) + num_scheduled_tokens = np.array(num_scheduled_tokens, dtype=np.int32) + assert max_num_scheduled_tokens > 0 + + # calculate block tables info for cpu + block_tables = self.input_batch.block_table_cpu_tensor[:num_reqs] + decode_block_tables = block_tables[num_prefills:] + prefill_block_tables = block_tables[:num_prefills] + # Get request indices. + # E.g., [2, 5, 3] -> [0, 0, 1, 1, 1, 1, 1, 2, 2, 2] + indices = np.arange(num_reqs) + req_indices = np.repeat(indices, num_scheduled_tokens) + + # Get batched arange. + # E.g., [2, 5, 3] -> [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + arange_matrix = np.tile(np.arange(max_num_scheduled_tokens), + (num_reqs, 1)) + mask = arange_matrix < num_scheduled_tokens[:, np.newaxis] + arange = arange_matrix[mask] + + # Get positions. + positions = torch.empty((total_num_scheduled_tokens, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + positions_np = positions.numpy() + np.add(self.input_batch.num_computed_tokens_cpu[req_indices], + arange, + out=positions_np) + + # Get token indices. + # E.g., [0, 1, 0, 1, 2, 3, 4, 0, 1, 2] + # -> [0, 1, M, M + 1, M + 2, M + 3, M + 4, 2 * M, 2 * M + 1, 2 * M + 2] + # where M is the max_model_len. + token_indices = positions_np + req_indices * self.max_model_len + token_indices = torch.from_numpy(token_indices) + input_ids = torch.empty((total_num_scheduled_tokens, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + torch.index_select(torch.from_numpy( + self.input_batch.token_ids_cpu).flatten(), + 0, + token_indices, + out=input_ids) + + # Calculate the slot mapping. + block_numbers = self.input_batch.block_table_cpu_tensor.flatten()[ + token_indices // self.block_size] + block_offsets = token_indices % self.block_size + slot_mapping = torch.empty((total_num_scheduled_tokens, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + torch.add(block_numbers * self.block_size, + block_offsets, + out=slot_mapping) + + # Prepare the attention metadata. + query_start_loc = torch.empty((num_reqs + 1, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + query_start_loc_np = query_start_loc.numpy() + query_start_loc_np[0] = 0 + np.cumsum(num_scheduled_tokens, out=query_start_loc_np[1:]) + + seq_lens = (self.input_batch.num_computed_tokens_cpu[:num_reqs] + + num_scheduled_tokens) + max_seq_len = seq_lens.max() + seq_start_loc = torch.empty((num_prefills + 1, ), + dtype=torch.int32, + device="cpu", + pin_memory=self.pin_memory) + seq_start_loc_np = seq_start_loc.numpy() + seq_start_loc_np[0] = 0 + np.cumsum(seq_lens[:num_prefills], out=seq_start_loc_np[1:]) + + # input_ids = input_ids.to(self.device, non_blocking=True) + self.positions[:total_num_scheduled_tokens].copy_(positions, + non_blocking=True) + # build input_data for cpu + data = self.input_data + data.use_mrope = False + data.seq_lens = seq_lens + data.query_lens = num_scheduled_tokens + data.num_decode_tokens = num_decode_tokens + data.num_prefills = num_prefills + data.num_prefill_tokens = num_prefill_tokens + data.input_tokens = input_ids + data.max_decode_seq_len = max_seq_len #? 
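+        # NOTE: max_seq_len is taken over all requests (prefills included),
+        # so max_decode_seq_len is an upper bound, not a decode-only max.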
+ data.decode_block_tables = decode_block_tables + data.prefill_block_tables = prefill_block_tables + data.slot_mapping = slot_mapping + data.query_start_loc = query_start_loc[:num_prefills+1] + data.seq_start_loc = seq_start_loc + attn_metadata = self.att_metadata_builder.build( + data.seq_lens, data.query_lens, -1, -1) + # NOTE(woosuk): Due to chunked prefills, there can be at most 1 partial + # request in the batch. While we should not sample any token from this + # partial request, we do so for simplicity. We will ignore the sampled + # token from the partial request. + # TODO: Support prompt logprobs. + logits_indices = query_start_loc[1:] - 1 + return input_ids, attn_metadata, logits_indices \ No newline at end of file diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py new file mode 100644 index 0000000000000..2dfd3ddeff499 --- /dev/null +++ b/vllm/v1/worker/cpu_worker.py @@ -0,0 +1,295 @@ +"""A CPU worker class.""" +import gc +import os +from typing import List, TYPE_CHECKING, Optional, Tuple, Dict + +import torch +import torch.distributed + +import vllm.envs as envs +from vllm.attention import get_attn_backend +from vllm.config import CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, VllmConfig +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce) +from vllm.logger import init_logger +from vllm.model_executor import set_random_seed +from vllm.platforms import current_platform +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size +from vllm.v1.outputs import ModelRunnerOutput +from vllm.v1.worker.cpu_model_runner import CPUModelRunner + +logger = init_logger(__name__) + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +class CPUCacheEngine: + """Manages the KV cache for CPU backend. + + This class is responsible for initializing and managing CPU KV + caches. It also provides methods for performing KV cache operations, such + as copying. + """ + + def __init__(self, cache_config: CacheConfig, model_config: ModelConfig, + parallel_config: ParallelConfig, + device_config: DeviceConfig) -> None: + assert device_config.device_type == "cpu" + self.cache_config = cache_config + self.model_config = model_config + self.parallel_config = parallel_config + + self.head_size = model_config.get_head_size() + self.num_layers = model_config.get_num_layers(parallel_config) + self.num_heads = model_config.get_num_kv_heads(parallel_config) + + self.block_size = cache_config.block_size + # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks + # for CPU backend, because we want to reuse KV cache management + # in the scheduler. + self.num_cpu_blocks = cache_config.num_gpu_blocks + + if cache_config.cache_dtype == "auto": + self.dtype = model_config.dtype + else: + self.dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + # Get attention backend. + self.attn_backend = get_attn_backend( + self.model_config.get_head_size(), + self.model_config.dtype, + cache_config.cache_dtype, + self.block_size, + self.model_config.is_attention_free, + ) + + # Initialize the cache. 
+ self.cpu_cache = self._allocate_kv_cache(self.num_cpu_blocks) + + def _allocate_kv_cache( + self, + num_blocks: int, + ) -> List[torch.Tensor]: + """Allocates KV cache on CPU.""" + kv_cache_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_heads, self.head_size) + kv_cache: List[torch.Tensor] = [] + for _ in range(self.num_layers): + kv_cache.append( + torch.empty(kv_cache_shape, dtype=self.dtype, device="cpu")) + return kv_cache + + def swap_in(self, src_to_dst: Dict[int, int]) -> None: + raise NotImplementedError("Swap is not supported in CPUCacheEngine.") + + def swap_out(self, src_to_dst: Dict[int, int]) -> None: + raise NotImplementedError("Swap is not supported in CPUCacheEngine.") + + def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: + self.attn_backend.copy_blocks(self.cpu_cache, src_to_dsts) + + @staticmethod + def get_cache_block_size( + block_size: int, + cache_dtype: str, + model_config: ModelConfig, + parallel_config: ParallelConfig, + ) -> int: + head_size = model_config.get_head_size() + num_heads = model_config.get_num_kv_heads(parallel_config) + num_layers = model_config.get_num_layers(parallel_config) + + key_cache_block = block_size * num_heads * head_size + value_cache_block = key_cache_block + total = num_layers * (key_cache_block + value_cache_block) + if cache_dtype == "auto": + dtype = model_config.dtype + else: + dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + dtype_size = torch.tensor([], dtype=dtype).element_size() + return dtype_size * total + +class CPUWorkerV1: + + def __init__( + self, + vllm_config: VllmConfig, + local_rank: int, + rank: int, + distributed_init_method: str, + is_driver_worker: bool, + ): + + # TODO: use WorkerBase.__init__(self, vllm_config=vllm_config) + self.vllm_config = vllm_config + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.load_config = vllm_config.load_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config + self.prompt_adapter_config = vllm_config.prompt_adapter_config + self.observability_config = vllm_config.observability_config + + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method + + if self.model_config.trust_remote_code: + # note: lazy import to avoid importing torch before initializing + from vllm.utils import init_cached_hf_modules + init_cached_hf_modules() + + # Setup OpenMP threads affinity. + omp_cpuids = envs.VLLM_CPU_OMP_THREADS_BIND + if omp_cpuids == "all": + self.local_omp_cpuid = "all" + else: + self.local_omp_cpuid = omp_cpuids.split("|")[rank] + + self.model_runner = CPUModelRunner(vllm_config) + # Torch profiler. Enabled and configured through env vars: + self.profiler = None + # Uninitialized cache engine. Will be initialized by + # initialize_cache. + self.cache_engine: List[CPUCacheEngine] + # Initialize cpu_cache as embedding models don't initialize kv_caches + self.cpu_cache: Optional[List[List[torch.Tensor]]] = None + + def initialize(self) -> None: + if self.local_omp_cpuid != "all": + ret = torch.ops._C_utils.init_cpu_threads_env(self.local_omp_cpuid) + if ret: + logger.info(ret) + self.device = torch.device("cpu") + self.init_distributed_environment() + # Set random seed. 
+ set_random_seed(self.model_config.seed) + + def load_model(self) -> None: + self.model_runner.load_model() + + @torch.inference_mode() + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of blocks available for the KV cache. + + This determines how many KV blocks can fit into the configured CPU + KV cache space. + + Note that since vLLM assumes a block resides on GPU if it can be + modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0. + This allows us to reuse the scheduler of vLLM without generalizing it + to different devices. + """ + # For CPU device, the block number will be calculated based on the + # cpu_kvcache_space. + cache_block_size = self.get_cache_block_size_bytes() + num_cpu_blocks = int(self.cache_config.cpu_kvcache_space_bytes // + cache_block_size) + num_cpu_blocks = max(num_cpu_blocks, 0) + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + num_gpu_blocks = num_cpu_blocks + num_cpu_blocks = 0 + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache. Currently, swappable CPU memory is not + supported. + + Since this worker does not support GPUs, we use the num_gpu_blocks to + determine how many non-swappable CPU blocks to allocate. + """ + assert (num_cpu_blocks == 0 + ), f"{type(self)} does not support swappable cache" + + # Note: To reuse the cache management procedure, + # use cpu cache as 'gpu cache'. + num_cpu_blocks = num_gpu_blocks + + self._validate_num_cpu_blocks(num_cpu_blocks) + self.cache_config.num_gpu_blocks = num_cpu_blocks + self.cache_config.num_cpu_blocks = 0 + + # Initialize the cache. + self._init_cache_engine() + + def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: + """Raise errors if the num_cpu_blocks is invalid. + """ + if num_cpu_blocks <= 0: + raise ValueError("No available memory for the cache blocks. " + "Try increasing `VLLM_CPU_KVCACHE_SPACE` when " + "initializing the engine.") + + max_seq_len = self.cache_config.block_size * num_cpu_blocks + if self.model_config.max_model_len > max_seq_len: + raise ValueError( + f"The model's max seq len ({self.model_config.max_model_len}) " + "is larger than the maximum number of tokens that can be " + f"stored in KV cache ({max_seq_len}). Try increasing " + "`VLLM_CPU_KVCACHE_SPACE` or decreasing `max_model_len` when " + "initializing the engine.") + + def _init_cache_engine(self) -> None: + self.cache_engine = [ + CPUCacheEngine(self.cache_config, self.model_config, + self.parallel_config, self.device_config) + for _ in range(self.parallel_config.pipeline_parallel_size) + ] + self.cpu_cache = [ + self.cache_engine[ve].cpu_cache + for ve in range(self.parallel_config.pipeline_parallel_size) + ] + self.model_runner.block_size = self.cache_engine[0].block_size + + assert all( + self.cpu_cache[ve] is not None + for ve in range(self.parallel_config.pipeline_parallel_size)) + + # Populate the cache to warmup the memory + for ve in range(self.parallel_config.pipeline_parallel_size): + for layer_cache in self.cpu_cache[ve]: + layer_cache.fill_(0) + + @torch.inference_mode() + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + output = self.model_runner.execute_model(self.cpu_cache[0], scheduler_output) + # TODO(woosuk): Send the output to the engine process. 
+        return output
+
+    def init_distributed_environment(self) -> None:
+        """Initialize the distributed environment."""
+
+        parallel_config = self.parallel_config
+        rank = self.rank
+        distributed_init_method = self.distributed_init_method
+        init_distributed_environment(
+            world_size=parallel_config.world_size,
+            rank=rank,
+            distributed_init_method=distributed_init_method,
+            backend="gloo",
+        )
+
+        # A small all_reduce for warmup.
+        torch.distributed.all_reduce(torch.zeros(1).cpu())
+
+        ensure_model_parallel_initialized(
+            parallel_config.tensor_parallel_size,
+            parallel_config.pipeline_parallel_size)
+
+    def get_cache_block_size_bytes(self) -> int:
+        """Return the size in bytes of a single KV cache block.
+        """
+        return CPUCacheEngine.get_cache_block_size(
+            self.cache_config.block_size, self.cache_config.cache_dtype,
+            self.model_config, self.parallel_config)
+
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index 420aaf8a1b4cd..6ce1e4d27ec76 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -125,6 +125,9 @@ def __init__(self, use_mrope: bool):
         self.num_prefills: int = 0
         self.num_prefill_tokens: int = 0
         self.num_decode_tokens: int = 0
+        # Start locations handed over by the v1 CPU model runner; these stay
+        # None on the v0 path, where the builder rebuilds them from the
+        # per-request lengths.
+        self.query_start_loc: Optional[torch.Tensor] = None
+        self.seq_start_loc: Optional[torch.Tensor] = None
         self.slot_mapping: List[int] = []
         self.multi_modal_inputs_list: List[MultiModalKwargs] = []
         self.multi_modal_placeholder_maps: Dict[
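
For reference, a minimal end-to-end smoke test for this CPU path could look like the sketch below. It assumes a CPU-only build with the v1 engine enabled via `VLLM_USE_V1=1`; the model name and the `VLLM_CPU_KVCACHE_SPACE` value are illustrative placeholders, not part of this change.

    # smoke_test_cpu_v1.py - hypothetical check that the v1 CPU executor path
    # is selected and can generate tokens.
    import os

    os.environ["VLLM_USE_V1"] = "1"  # route requests through vllm.v1
    os.environ["VLLM_CPU_KVCACHE_SPACE"] = "4"  # GiB reserved for CPU KV cache

    from vllm import LLM, SamplingParams

    # On a CPU-only platform, _get_executor_cls() now returns CPUExecutor.
    llm = LLM(model="facebook/opt-125m", dtype="bfloat16")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(temperature=0.0, max_tokens=16))
    print(outputs[0].outputs[0].text)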