From a491d6f535d96939d17e5290991dc975495c9580 Mon Sep 17 00:00:00 2001 From: Rui Qiao <161574667+ruisearch42@users.noreply.github.com> Date: Mon, 23 Dec 2024 15:00:12 -0800 Subject: [PATCH] [V1] TP Ray executor (#11107) Signed-off-by: Rui Qiao --- .../test_basic_correctness.py | 2 +- vllm/v1/engine/llm_engine.py | 7 +- vllm/v1/executor/ray_executor.py | 339 ++++++++++++++++++ vllm/v1/executor/ray_utils.py | 271 ++++++++++++++ vllm/v1/worker/gpu_worker.py | 1 - 5 files changed, 617 insertions(+), 3 deletions(-) create mode 100644 vllm/v1/executor/ray_executor.py create mode 100644 vllm/v1/executor/ray_utils.py diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 11d05cefb7313..9e4eb16fc6cc5 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -130,7 +130,7 @@ def test_models_distributed( # Import VLLM_USE_V1 dynamically to handle patching from vllm.envs import VLLM_USE_V1 if VLLM_USE_V1 and distributed_executor_backend != "mp": - pytest.skip(f"Skip {distributed_executor_backend} for V1") + os.environ["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0" dtype = "half" max_tokens = 5 diff --git a/vllm/v1/engine/llm_engine.py b/vllm/v1/engine/llm_engine.py index bea8c5502f612..9ad51575b3cc3 100644 --- a/vllm/v1/engine/llm_engine.py +++ b/vllm/v1/engine/llm_engine.py @@ -21,6 +21,7 @@ from vllm.v1.engine.detokenizer import Detokenizer from vllm.v1.engine.processor import Processor from vllm.v1.executor.abstract import Executor +from vllm.v1.executor.ray_utils import initialize_ray_cluster logger = init_logger(__name__) @@ -110,7 +111,11 @@ def _get_executor_cls(cls, vllm_config: VllmConfig) -> Type[Executor]: executor_class: Type[Executor] distributed_executor_backend = ( vllm_config.parallel_config.distributed_executor_backend) - if distributed_executor_backend == "mp": + if distributed_executor_backend == "ray": + initialize_ray_cluster(vllm_config.parallel_config) + from vllm.v1.executor.ray_executor import RayExecutor + executor_class = RayExecutor + elif distributed_executor_backend == "mp": from vllm.v1.executor.multiproc_executor import MultiprocExecutor executor_class = MultiprocExecutor else: diff --git a/vllm/v1/executor/ray_executor.py b/vllm/v1/executor/ray_executor.py new file mode 100644 index 0000000000000..dfeb69fa701a3 --- /dev/null +++ b/vllm/v1/executor/ray_executor.py @@ -0,0 +1,339 @@ +import os +from collections import defaultdict +from itertools import islice, repeat +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple + +import vllm.envs as envs +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils import get_distributed_init_method, get_ip, get_open_port +from vllm.v1.executor.abstract import Executor +from vllm.v1.executor.ray_utils import RayWorkerWrapper, ray +from vllm.v1.outputs import ModelRunnerOutput + +if ray is not None: + from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy + +if TYPE_CHECKING: + from ray.util.placement_group import PlacementGroup + +logger = init_logger(__name__) + + +class RayExecutor(Executor): + + def __init__(self, vllm_config: VllmConfig) -> None: + self.vllm_config = vllm_config + self.parallel_config = vllm_config.parallel_config + self.model_config = vllm_config.model_config + self.forward_dag: Optional[ray.dag.CompiledDAG] = None + + # Disable Ray usage stats collection. + ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0") + if ray_usage != "1": + os.environ["RAY_USAGE_STATS_ENABLED"] = "0" + + placement_group = self.parallel_config.placement_group + # Create the parallel GPU workers. + self._init_workers_ray(placement_group) + + def _init_workers_ray(self, placement_group: "PlacementGroup", + **ray_remote_kwargs): + # A list of workers to run a model. + self.workers: List[RayWorkerWrapper] = [] + if self.parallel_config.ray_workers_use_nsight: + ray_remote_kwargs = self._configure_ray_workers_use_nsight( + ray_remote_kwargs) + + # Create the workers. + driver_ip = get_ip() + for bundle_id, bundle in enumerate(placement_group.bundle_specs): + if not bundle.get("GPU", 0): + # Skip bundles that don't have GPUs, + # as each worker needs one GPU. + continue + scheduling_strategy = PlacementGroupSchedulingStrategy( + placement_group=placement_group, + placement_group_capture_child_tasks=True, + placement_group_bundle_index=bundle_id, + ) + + worker = ray.remote( + num_cpus=0, + num_gpus=1, + scheduling_strategy=scheduling_strategy, + **ray_remote_kwargs, + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) + self.workers.append(worker) + + logger.debug("workers: %s", self.workers) + worker_ips = [ + ray.get(worker.get_node_ip.remote()) # type: ignore[attr-defined] + for worker in self.workers + ] + ip_counts: Dict[str, int] = {} + for ip in worker_ips: + ip_counts[ip] = ip_counts.get(ip, 0) + 1 + + worker_to_ip = dict(zip(self.workers, worker_ips)) + + def sort_by_driver_then_worker_ip(worker): + """ + Sort the workers based on 3 properties: + 1. If the worker is on the same node as the driver (vllm engine), + it should be placed first. + 2. Then, if the worker is on a node with fewer workers, it should + be placed first. + 3. Finally, if the work is on a node with smaller IP address, it + should be placed first. This is simply a tiebreaker to make + sure the workers are sorted in a deterministic way. + """ + ip = worker_to_ip[worker] + return (ip != driver_ip, ip_counts[ip], ip) + + # After sorting, the workers on the same node will be + # close to each other, and the workers on the driver + # node will be placed first. + self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip) + + # Get the set of GPU IDs used on each node. + worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids") + + node_workers = defaultdict(list) # node id -> list of worker ranks + node_gpus = defaultdict(list) # node id -> list of gpu ids + + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids): + node_workers[node_id].append(i) + # `gpu_ids` can be a list of strings or integers. + # convert them to integers for consistency. + # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs), + # string sorting is not sufficient. + # see https://github.com/vllm-project/vllm/issues/5590 + gpu_ids = [int(x) for x in gpu_ids] + node_gpus[node_id].extend(gpu_ids) + + for node_id, gpu_ids in node_gpus.items(): + node_gpus[node_id] = sorted(gpu_ids) + + all_ips = set(worker_ips) + n_ips = len(all_ips) + n_nodes = len(node_workers) + + if n_nodes != n_ips: + raise RuntimeError( + f"Every node should have a unique IP address. Got {n_nodes}" + f" nodes with node ids {list(node_workers.keys())} and " + f"{n_ips} unique IP addresses {all_ips}. Please check your" + " network configuration. If you set `VLLM_HOST_IP` or " + "`HOST_IP` environment variable, make sure it is unique for" + " each node.") + + # Set environment variables for the driver and workers. + all_args_to_update_environment_variables = [({ + "CUDA_VISIBLE_DEVICES": + ",".join(map(str, node_gpus[node_id])), + "VLLM_TRACE_FUNCTION": + str(envs.VLLM_TRACE_FUNCTION), + "VLLM_USE_V1": + str(int(envs.VLLM_USE_V1)), + **({ + "VLLM_ATTENTION_BACKEND": envs.VLLM_ATTENTION_BACKEND + } if envs.VLLM_ATTENTION_BACKEND is not None else {}) + }, ) for (node_id, _) in worker_node_and_gpu_ids] + + self._env_vars_for_all_workers = ( + all_args_to_update_environment_variables) + + self._run_workers("update_environment_variables", + all_args=self._get_env_vars_to_be_updated()) + + if len(node_gpus) == 1: + # in single node case, we don't need to get the IP address. + # the loopback address is sufficient + # NOTE: a node may have several IP addresses, one for each + # network interface. `get_ip()` might return any of them, + # while they might not work for communication inside the node + # if the network setup is complicated. Using the loopback address + # solves this issue, as it always works for communication inside + # the node. + driver_ip = "127.0.0.1" + distributed_init_method = get_distributed_init_method( + driver_ip, get_open_port()) + + # Initialize the actual workers inside worker wrapper. + init_worker_all_kwargs = [ + self._get_worker_kwargs( + local_rank=node_workers[node_id].index(rank), + rank=rank, + distributed_init_method=distributed_init_method, + ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids) + ] + self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) + self._run_workers("initialize") + self._run_workers("load_model") + + def _configure_ray_workers_use_nsight(self, + ray_remote_kwargs) -> Dict[str, Any]: + # If nsight profiling is enabled, we need to set the profiling + # configuration for the ray workers as runtime env. + runtime_env = ray_remote_kwargs.setdefault("runtime_env", {}) + runtime_env.update({ + "nsight": { + "t": "cuda,cudnn,cublas", + "o": "'worker_process_%p'", + "cuda-graph-trace": "node", + } + }) + + return ray_remote_kwargs + + def _get_env_vars_to_be_updated(self): + return self._env_vars_for_all_workers + + def _get_worker_kwargs( + self, + local_rank: int = 0, + rank: int = 0, + distributed_init_method: Optional[str] = None) -> Dict[str, Any]: + """ + Return worker init args for a given rank. + """ + if distributed_init_method is None: + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) + return dict( + vllm_config=self.vllm_config, + local_rank=local_rank, + rank=rank, + distributed_init_method=distributed_init_method, + ) + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """ + Determine the number of available KV blocks. + + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. + + Returns: + - tuple[num_gpu_blocks, num_cpu_blocks] + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self._run_workers("determine_num_available_blocks") + + # Since we use a shared centralized controller, we take the minimum + # number of blocks across all workers to make sure all the memory + # operators can be applied to all workers. + num_gpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + + return num_gpu_blocks, num_cpu_blocks + + def initialize(self, num_gpu_blocks: int) -> None: + """ + Initialize the KV cache in all workers. + """ + # NOTE: This is logged in the executor because there can be >1 worker + # with other executors. We could log in the engine level, but work + # remains to abstract away the device for non-GPU configurations. + logger.info("# GPU blocks: %d", num_gpu_blocks) + self._run_workers("initialize_cache", num_gpu_blocks) + self._run_workers("compile_or_warm_up_model") + + def _run_workers( + self, + method: str, + *args, + all_args: Optional[List[Tuple[Any, ...]]] = None, + all_kwargs: Optional[List[Dict[str, Any]]] = None, + **kwargs, + ) -> Any: + """ + Runs the given method on all workers. Can be used in the following + ways: + + Args: + - args/kwargs: All workers share the same args/kwargs + - all_args/all_kwargs: args/kwargs for each worker are specified + individually + """ + count = len(self.workers) + all_worker_args = repeat(args, count) if all_args is None \ + else islice(all_args, 0, None) + all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \ + else islice(all_kwargs, 0, None) + + ray_worker_refs = [ + worker.execute_method.remote( # type: ignore[attr-defined] + method, *worker_args, **worker_kwargs) + for (worker, worker_args, worker_kwargs + ) in zip(self.workers, all_worker_args, all_worker_kwargs) + ] + return ray.get(ray_worker_refs) + + def execute_model( + self, + scheduler_output, + ) -> ModelRunnerOutput: + if self.forward_dag is None: + self.forward_dag = self._compiled_ray_dag() + # Only the first worker (with rank 0) returns the execution result. + # Others return None. + output = ray.get(self.forward_dag.execute(scheduler_output))[0] + return output + + def profile(self, is_start=True): + raise NotImplementedError + + def shutdown(self): + if hasattr(self, "forward_dag") and self.forward_dag is not None: + self.forward_dag.teardown() + import ray + for worker in self.workers: + ray.kill(worker) + self.forward_dag = None + + def check_health(self) -> None: + logger.debug("Called check_health.") + + def _check_ray_compiled_graph_installation(self): + import pkg_resources + from packaging import version + + required_version = version.parse("2.39") + current_version = version.parse( + pkg_resources.get_distribution("ray").version) + if current_version < required_version: + raise ValueError(f"Ray version {required_version} is " + f"required, but found {current_version}") + + import importlib.util + raycg = importlib.util.find_spec("ray.experimental.compiled_dag_ref") + if raycg is None: + raise ValueError("Ray Compiled Graph is not installed. " + "Run `pip install ray[adag]` to install it.") + + cupy_spec = importlib.util.find_spec("cupy") + if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: + raise ValueError( + "cupy is not installed but required since " + "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set." + "Run `pip install ray[adag]` and check cupy installation.") + + def _compiled_ray_dag(self): + assert self.parallel_config.use_ray + self._check_ray_compiled_graph_installation() + from ray.dag import InputNode, MultiOutputNode + + with InputNode() as input_batches: + outputs = [ + worker.execute_model.bind( # type: ignore[attr-defined] + input_batches) for worker in self.workers + ] + forward_dag = MultiOutputNode(outputs) + + return forward_dag.experimental_compile() + + def __del__(self): + self.shutdown() diff --git a/vllm/v1/executor/ray_utils.py b/vllm/v1/executor/ray_utils.py new file mode 100644 index 0000000000000..7733610e59c7f --- /dev/null +++ b/vllm/v1/executor/ray_utils.py @@ -0,0 +1,271 @@ +import time +from collections import defaultdict +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple + +from vllm.config import ParallelConfig +from vllm.logger import init_logger +from vllm.platforms import current_platform +from vllm.utils import get_ip +from vllm.v1.outputs import ModelRunnerOutput +from vllm.worker.worker_base import WorkerWrapperBase + +if TYPE_CHECKING: + from vllm.v1.core.scheduler import SchedulerOutput + +logger = init_logger(__name__) +PG_WAIT_TIMEOUT = 60 + +try: + import ray + from ray.util import placement_group_table + from ray.util.placement_group import PlacementGroup + try: + from ray._private.state import available_resources_per_node + except ImportError: + # Ray 2.9.x doesn't expose `available_resources_per_node` + from ray._private.state import state as _state + available_resources_per_node = _state._available_resources_per_node + + class RayWorkerWrapper(WorkerWrapperBase): + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + # Since the compiled DAG runs a main execution + # in a different thread that calls cuda.set_device. + # The flag indicates is set_device is called on + # that thread. It will be removed soon. + self.compiled_dag_cuda_device_set = False + + def get_node_ip(self) -> str: + return get_ip() + + def get_node_and_gpu_ids(self) -> Tuple[str, List[int]]: + node_id = ray.get_runtime_context().get_node_id() + gpu_ids = ray.get_gpu_ids() + return node_id, gpu_ids + + def setup_device_if_necessary(self): + # TODO(swang): This is needed right now because Ray CG executes + # on a background thread, so we need to reset torch's current + # device. + # We can remove this API after it is fixed in compiled graph. + import torch + assert self.worker is not None, "Worker is not initialized" + if not self.compiled_dag_cuda_device_set: + torch.cuda.set_device(self.worker.device) + self.compiled_dag_cuda_device_set = True + + def execute_model( + self, + scheduler_output: "SchedulerOutput", + ) -> ModelRunnerOutput: + self.setup_device_if_necessary() + assert self.worker is not None, "Worker is not initialized" + output = self.worker.model_runner.execute_model(scheduler_output) + return output + + ray_import_err = None + +except ImportError as e: + ray = None # type: ignore + ray_import_err = e + RayWorkerWrapper = None # type: ignore + + +def ray_is_available() -> bool: + """Returns True if Ray is available.""" + return ray is not None + + +def assert_ray_available(): + """ + Raise an exception if Ray is not available. + """ + if ray is None: + raise ValueError("Failed to import Ray, please install Ray with " + "`pip install ray`.") from ray_import_err + + +def _verify_bundles(placement_group: "PlacementGroup", + parallel_config: ParallelConfig, device_str: str): + """ + Verify a given placement group has bundles located in the right place. + + There are 2 rules. + - Warn if all tensor parallel workers cannot fit in a single node. + - Fail if driver node is not included in a placement group. + + Args: + placement_group: The placement group to verify. + parallel_config: The parallel configuration. + device_str: The required device. + """ + assert ray.is_initialized(), ( + "Ray is not initialized although distributed-executor-backend is ray.") + pg_data = placement_group_table(placement_group) + # bundle_idx -> node_id + bundle_to_node_ids = pg_data["bundles_to_node_id"] + # bundle_idx -> bundle (e.g., {"GPU": 1}) + bundles = pg_data["bundles"] + # node_id -> List of bundle (e.g., {"GPU": 1}) + node_id_to_bundle: Dict[str, List[Dict[str, float]]] = defaultdict(list) + + for bundle_idx, node_id in bundle_to_node_ids.items(): + node_id_to_bundle[node_id].append(bundles[bundle_idx]) + driver_node_id = ray.get_runtime_context().get_node_id() + + if driver_node_id not in node_id_to_bundle: + raise RuntimeError( + f"driver node id {driver_node_id} is not included in a placement " + f"group {placement_group.id}. Node id -> bundles " + f"{node_id_to_bundle}. " + "You don't have enough GPUs available in a current node. Check " + "`ray status` to see if you have available GPUs in a node " + f"{driver_node_id} before starting an vLLM engine.") + + for node_id, bundles in node_id_to_bundle.items(): + if len(bundles) < parallel_config.tensor_parallel_size: + logger.warning( + "tensor_parallel_size=%d " + "is bigger than a reserved number of %ss (%d " + "%ss) in a node %s. Tensor parallel workers can be " + "spread out to 2+ nodes which can degrade the performance " + "unless you have fast interconnect across nodes, like " + "Infiniband. To resolve this issue, make sure you have more " + "than %d GPUs available at each node.", + parallel_config.tensor_parallel_size, device_str, len(bundles), + device_str, node_id, parallel_config.tensor_parallel_size) + + +def _wait_until_pg_ready(current_placement_group: "PlacementGroup"): + """Wait until a placement group is ready. + + It prints the informative log messages if the placement group is + not created within time. + + """ + # Wait until PG is ready - this will block until all + # requested resources are available, and will timeout + # if they cannot be provisioned. + placement_group_specs = current_placement_group.bundle_specs + + s = time.time() + pg_ready_ref = current_placement_group.ready() + wait_interval = 10 + while time.time() - s < PG_WAIT_TIMEOUT: + ready, _ = ray.wait([pg_ready_ref], timeout=wait_interval) + if len(ready) > 0: + break + + # Exponential backoff for warning print. + wait_interval *= 2 + logger.info( + "Waiting for creating a placement group of specs for " + "%d seconds. specs=%s. Check " + "`ray status` to see if you have enough resources.", + int(time.time() - s), placement_group_specs) + + try: + ray.get(pg_ready_ref, timeout=0) + except ray.exceptions.GetTimeoutError: + raise ValueError( + "Cannot provide a placement group of " + f"{placement_group_specs=} within {PG_WAIT_TIMEOUT} seconds. See " + "`ray status` to make sure the cluster has enough resources." + ) from None + + +def initialize_ray_cluster( + parallel_config: ParallelConfig, + ray_address: Optional[str] = None, +): + """Initialize the distributed cluster with Ray. + + it will connect to the Ray cluster and create a placement group + for the workers, which includes the specification of the resources + for each distributed worker. + + Args: + parallel_config: The configurations for parallel execution. + ray_address: The address of the Ray cluster. If None, uses + the default Ray cluster address. + """ + assert_ray_available() + + # Connect to a ray cluster. + if current_platform.is_rocm() or current_platform.is_xpu(): + # Try to connect existing ray instance and create a new one if not found + try: + ray.init("auto") + except ConnectionError: + logger.warning( + "No existing RAY instance detected. " + "A new instance will be launched with current node resources.") + ray.init(address=ray_address, + ignore_reinit_error=True, + num_gpus=parallel_config.world_size) + else: + ray.init(address=ray_address, ignore_reinit_error=True) + + if parallel_config.placement_group: + # Placement group is already set. + return + + device_str = "GPU" if not current_platform.is_tpu() else "TPU" + # Create placement group for worker processes + current_placement_group = ray.util.get_current_placement_group() + if current_placement_group: + # We are in a placement group + bundles = current_placement_group.bundle_specs + # Verify that we can use the placement group. + device_bundles = 0 + for bundle in bundles: + bundle_devices = bundle.get(device_str, 0) + if bundle_devices > 1: + raise ValueError( + "Placement group bundle cannot have more than 1 " + f"{device_str}.") + if bundle_devices: + device_bundles += 1 + if parallel_config.world_size > device_bundles: + raise ValueError( + f"The number of required {device_str}s exceeds the total " + f"number of available {device_str}s in the placement group." + f"Required number of devices: {parallel_config.world_size}. " + f"Total number of devices: {device_bundles}.") + else: + num_devices_in_cluster = ray.cluster_resources().get(device_str, 0) + if parallel_config.world_size > num_devices_in_cluster: + raise ValueError( + f"The number of required {device_str}s exceeds the total " + f"number of available {device_str}s in the placement group.") + # Create a new placement group + placement_group_specs: List[Dict[str, float]] = ([{ + device_str: 1.0 + } for _ in range(parallel_config.world_size)]) + + # vLLM engine is also a worker to execute model with an accelerator, + # so it requires to have the device in a current node. Check if + # the current node has at least one device. + current_ip = get_ip() + current_node_id = ray.get_runtime_context().get_node_id() + current_node_resource = available_resources_per_node()[current_node_id] + if current_node_resource.get(device_str, 0) < 1: + raise ValueError( + f"Current node has no {device_str} available. " + f"{current_node_resource=}. vLLM engine cannot start without " + f"{device_str}. Make sure you have at least 1 {device_str} " + f"available in a node {current_node_id=} {current_ip=}.") + # This way, at least bundle is required to be created in a current + # node. + placement_group_specs[0][f"node:{current_ip}"] = 0.001 + + # By default, Ray packs resources as much as possible. + current_placement_group = ray.util.placement_group( + placement_group_specs, strategy="PACK") + _wait_until_pg_ready(current_placement_group) + + assert current_placement_group is not None + _verify_bundles(current_placement_group, parallel_config, device_str) + # Set the placement group in the parallel config + parallel_config.placement_group = current_placement_group diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index 33491f700de10..0000b09bfaa36 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -202,7 +202,6 @@ def execute_model( ) -> ModelRunnerOutput: output = self.model_runner.execute_model(scheduler_output) return output if self.rank == 0 else None - return output def profile(self, is_start: bool = True): if self.profiler is None: