From 573d228feb262984ade1f13dedf927e55165e5de Mon Sep 17 00:00:00 2001 From: "rshaw@neuralmagic.com" Date: Thu, 12 Dec 2024 22:05:05 +0000 Subject: [PATCH] e2e working --- examples/openai_completion_client.py | 2 +- vllm/engine/multiprocessing/__init__.py | 1 - vllm/engine/multiprocessing/client.py | 49 ++++++++++++++++--------- vllm/engine/multiprocessing/engine.py | 48 ++++++++++++++---------- vllm/engine/multiprocessing/ipc.py | 43 ++++++++++++---------- vllm/entrypoints/openai/api_server.py | 8 +++- 6 files changed, 91 insertions(+), 60 deletions(-) diff --git a/examples/openai_completion_client.py b/examples/openai_completion_client.py index 8effc00120d43..4d956c4beea33 100644 --- a/examples/openai_completion_client.py +++ b/examples/openai_completion_client.py @@ -2,7 +2,7 @@ # Modify OpenAI's API key and API base to use vLLM's API server. openai_api_key = "EMPTY" -openai_api_base = "http://localhost:8001/v1" +openai_api_base = "http://localhost:8002/v1" client = OpenAI( # defaults to os.environ.get("OPENAI_API_KEY") diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index af003bc6eff15..420f540d0b5f4 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -14,7 +14,6 @@ VLLM_RPC_SUCCESS_STR = "SUCCESS" - IPC_INPUT_EXT = "_input_socket" IPC_OUTPUT_EXT = "_output_socket" IPC_HEALTH_EXT = "_health_socket" diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 8eba0caca5c2b..8988e6051f119 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -28,7 +28,7 @@ RPCError, RPCProcessRequest, RPCStartupRequest, RPCStartupResponse, RPCUProfileRequest) -from vllm.engine.multiprocessing.ipc import (send_signed_async, +from vllm.engine.multiprocessing.ipc import (send_signed_async, recv_signed_async) from vllm.engine.protocol import EngineClient @@ -84,8 +84,9 @@ class MQLLMEngineClient(EngineClient): """ def __init__(self, ipc_path: str, engine_config: VllmConfig, - engine_pid: int): + engine_pid: int, secret_key: bytes): self.context = zmq.asyncio.Context() + self.secret_key = secret_key self._errored_with: Optional[BaseException] = None # Get the configs. @@ -163,6 +164,7 @@ async def run_heartbeat_loop(self, timeout: int): # Heartbeat received- check the message await self._check_success( error_message="Heartbeat failed.", + secret_key=self.secret_key, socket=self.heartbeat_socket) logger.debug("Heartbeat successful.") @@ -195,7 +197,8 @@ async def run_output_handler_loop(self): ENGINE_DEAD_ERROR(self._errored_with)) return - message = await recv_signed_async(self.output_socket) + message = await recv_signed_async(self.output_socket, + self.secret_key) request_outputs = pickle.loads(message) is_error = isinstance(request_outputs, @@ -290,11 +293,12 @@ def _set_errored(self, e: BaseException): async def _send_get_data_rpc_request(request: RPCStartupRequest, expected_type: Any, error_message: str, - socket: Socket) -> Any: + socket: Socket, + secret_key: bytes) -> Any: """Send an RPC request that is expecting data back.""" # Ping RPCServer with a request. - await send_signed_async(socket, pickle.dumps(request)) + await send_signed_async(socket, secret_key, pickle.dumps(request)) # Make sure the server responds in time. if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0: @@ -302,8 +306,8 @@ async def _send_get_data_rpc_request(request: RPCStartupRequest, f"{VLLM_RPC_TIMEOUT} ms") # Await the data from the Server. - frame = await recv_signed_async(socket) - data = pickle.loads(frame.buffer) + message = await recv_signed_async(socket, secret_key) + data = pickle.loads(message) if isinstance(data, BaseException): raise data @@ -314,13 +318,14 @@ async def _send_get_data_rpc_request(request: RPCStartupRequest, @staticmethod async def _send_one_way_rpc_request(request: RPC_REQUEST_T, - socket: Socket): + socket: Socket, + secret_key: bytes): """Send one-way RPC request to trigger an action.""" if socket.closed: raise MQClientClosedError() - await send_signed_async(socket, pickle.dumps(request)) + await send_signed_async(socket, secret_key, pickle.dumps(request)) async def _await_ack(self, error_message: str, socket: Socket): """Await acknowledgement that a request succeeded.""" @@ -332,17 +337,18 @@ async def _await_ack(self, error_message: str, socket: Socket): raise TimeoutError("MQLLMEngine didn't reply within " f"{VLLM_RPC_TIMEOUT}ms") - await self._check_success(error_message, socket) + await self._check_success(error_message, self.secret_key, socket) @staticmethod - async def _check_success(error_message: str, socket: Socket): + async def _check_success(error_message: str, secret_key: bytes, + socket: Socket): """Confirm that socket has a VLLM_RPC_SUCCESS_STR message""" if socket.closed: raise MQClientClosedError() - frame = await recv_signed_async(socket, error_message) - response = pickle.loads(frame.buffer) + message = await recv_signed_async(socket, secret_key) + response = pickle.loads(message) # Raise error if unsuccessful if isinstance(response, BaseException): @@ -373,6 +379,7 @@ async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse: request=RPCStartupRequest.IS_SERVER_READY, expected_type=RPCStartupResponse, error_message="Unable to start RPC Server", + secret_key=self.secret_key, socket=socket) async def abort(self, request_id: str): @@ -627,9 +634,11 @@ async def _process_request( )) # 3) Send the RPCGenerateRequest to the MQLLMEngine. - parts = (request_bytes, - lp_bytes) if lp_bytes else (request_bytes, ) - await send_signed_async(self.input_socket, parts) + + # parts = (request_bytes, + # lp_bytes) if lp_bytes else (request_bytes, ) + await send_signed_async(self.input_socket, self.secret_key, + request_bytes) # 4) Stream the RequestOutputs from the output queue. Note # that the output_loop pushes RequestOutput objects to this @@ -655,10 +664,14 @@ async def start_profile(self) -> None: """Start profiling the engine""" await self._send_one_way_rpc_request( - request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket) + request=RPCUProfileRequest.START_PROFILE, + secret_key=self.secret_key, + socket=self.input_socket) async def stop_profile(self) -> None: """Stop profiling the engine""" await self._send_one_way_rpc_request( - request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket) + request=RPCUProfileRequest.STOP_PROFILE, + secret_key=self.secret_key, + socket=self.input_socket) diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index c72207e4ade66..831e32c346e6c 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -1,6 +1,5 @@ import pickle import signal -import hmac from contextlib import contextmanager from typing import Iterator, List, Optional, Union @@ -18,7 +17,8 @@ RPCError, RPCProcessRequest, RPCStartupRequest, RPCStartupResponse, RPCUProfileRequest) -from vllm.engine.multiprocessing.ipc import send +from vllm.engine.multiprocessing.ipc import (send_signed, recv_signed, + check_signed, sign) # yapf: enable from vllm.executor.gpu_executor import GPUExecutor @@ -29,7 +29,7 @@ logger = init_logger(__name__) POLLING_TIMEOUT_MS = 10000 -HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), ) +HEALTHY_RESPONSE = pickle.dumps(VLLM_RPC_SUCCESS_STR) class MQLLMEngine: @@ -63,6 +63,7 @@ class MQLLMEngine: def __init__(self, ipc_path: str, use_async_sockets: bool, + secret_key: bytes, *args, log_requests: bool = True, **kwargs) -> None: @@ -80,6 +81,7 @@ def __init__(self, self._async_socket_engine_callback self.ctx = zmq.Context() # type: ignore[attr-defined] + self.secret_key = secret_key # Receive input from the client. self.input_socket = self.ctx.socket(zmq.constants.PULL) @@ -108,8 +110,10 @@ def dead_error(self) -> BaseException: @classmethod def from_engine_args(cls, engine_args: AsyncEngineArgs, - usage_context: UsageContext, ipc_path: str): + usage_context: UsageContext, ipc_path: str, + secret_key: bytes): """Creates an MQLLMEngine from the engine arguments.""" + # Setup plugins for each process from vllm.plugins import load_general_plugins load_general_plugins() @@ -121,6 +125,7 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs, return cls(ipc_path=ipc_path, use_async_sockets=use_async_sockets, + secret_key=secret_key, vllm_config=engine_config, executor_class=executor_class, log_requests=not engine_args.disable_log_requests, @@ -164,7 +169,9 @@ def run_startup_loop(self) -> None: with self.make_data_socket() as socket: response: Union[RPCStartupResponse, BaseException] try: - identity, message = socket.recv_multipart(copy=False) + identity, sig, message = socket.recv_multipart(copy=False) + if not check_signed(self.secret_key, sig, message.buffer): + raise ValueError("Message Signature is invalid.") request: RPCStartupRequest = pickle.loads(message.buffer) # Handle the query from the Client. @@ -176,8 +183,9 @@ def run_startup_loop(self) -> None: except Exception as e: response = e - socket.send_multipart((identity, pickle.dumps(response)), - copy=False) + response_bytes = pickle.dumps(response) + sig = sign(self.secret_key, response_bytes) + socket.send_multipart((identity, sig, response_bytes), copy=False) def run_engine_loop(self): """Core busy loop of the LLMEngine.""" @@ -220,15 +228,16 @@ def handle_new_input(self): """Handle new input from the socket""" try: while self.input_socket.poll(timeout=0) != 0: - frames = self.input_socket.recv_multipart(copy=False) - request = pickle.loads(frames[0].buffer) + message = recv_signed(self.input_socket, self.secret_key) + request = pickle.loads(message) if isinstance(request, RPCProcessRequest): - if len(frames) > 1: - # Use cloudpickle for logits processors - assert isinstance(request.params, SamplingParams) - lprocs = cloudpickle.loads(frames[1].buffer) - request.params.logits_processors = lprocs + # TODO: handle cp case. + # if len(frames) > 1: + # # Use cloudpickle for logits processors + # assert isinstance(request.params, SamplingParams) + # lprocs = cloudpickle.loads(frames[1].buffer) + # request.params.logits_processors = lprocs self._handle_process_request(request) elif isinstance(request, RPCAbortRequest): self._handle_abort_request(request) @@ -313,18 +322,18 @@ def _send_outputs(self, outputs: REQUEST_OUTPUTS_T): pass output_bytes = pickle.dumps(outputs) - send(self.output_socket, output_bytes) + send_signed(self.output_socket, self.secret_key, output_bytes) def _send_healthy(self): """Send HEALTHY message to RPCClient.""" if not self.heartbeat_socket.closed: - send(self.heartbeat_socket, HEALTHY_RESPONSE) + send_signed(self.heartbeat_socket, self.secret_key, HEALTHY_RESPONSE) def _send_unhealthy(self, error: BaseException): """Send UNHEALTHY message to RPCClient.""" if not self.heartbeat_socket.closed: error_bytes = pickle.dumps(error) - send(self.heartbeat_socket, error_bytes) + send_signed(self.heartbeat_socket, self.secret_key, error_bytes) def _async_socket_engine_callback(self, request_outputs: REQUEST_OUTPUTS_T): @@ -355,11 +364,12 @@ def signal_handler(*_) -> None: def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext, - ipc_path: str, engine_alive): + ipc_path: str, secret_key: bytes, engine_alive): try: engine = MQLLMEngine.from_engine_args(engine_args=engine_args, usage_context=usage_context, - ipc_path=ipc_path) + ipc_path=ipc_path, + secret_key=secret_key) signal.signal(signal.SIGTERM, signal_handler) diff --git a/vllm/engine/multiprocessing/ipc.py b/vllm/engine/multiprocessing/ipc.py index e9193ab1040d9..645273af86e88 100644 --- a/vllm/engine/multiprocessing/ipc.py +++ b/vllm/engine/multiprocessing/ipc.py @@ -1,50 +1,55 @@ import hashlib import hmac -import secrets import zmq import zmq.asyncio -# TODO: switch to SECRET_KEY = secrets.token_bytes(16) -# and pass the SECRET_KEY to the background process. -SECRET_KEY = b"my_key" -def sign(msg: bytes) -> bytes: - """Compute the HMAC digest of msg, given signing key `key`""" +def sign(key: bytes, msg: bytes) -> bytes: + """Compute the HMAC digest of msg, given signing key""" + return hmac.HMAC( - SECRET_KEY, + key, msg, digestmod=hashlib.sha256, ).digest() -def check_signed(sig: bytes, msg: bytes) -> bool: - correct_sig = sign(msg) + +def check_signed(key: bytes, sig: bytes, msg: bytes) -> bool: + """Check if signature (HMAC digest) matches.""" + + correct_sig = sign(key, msg) return hmac.compare_digest(sig, correct_sig) -def send_signed(socket: zmq.Socket, msg: bytes): + +def send_signed(socket: zmq.Socket, key: bytes, msg: bytes): """Send signed message to socket.""" - sig = sign(msg) + sig = sign(key, msg) socket.send_multipart((sig, msg), copy=False) -def recv_signed(socket: zmq.Socket): + +def recv_signed(socket: zmq.Socket, key: bytes) -> bytes: """Get signed message from socket.""" sig, msg = socket.recv_multipart(copy=False) - if not check_signed(sig, msg): + if not check_signed(key, sig, msg.buffer): raise ValueError("Message signature is invalid.") - return msg + return msg.buffer + -async def send_signed_async(socket: zmq.asyncio.Socket, msg: bytes): +async def send_signed_async(socket: zmq.asyncio.Socket, key: bytes, + msg: bytes): """Send signed message to asyncio socket.""" - sig = sign(msg) + sig = sign(key, msg) await socket.send_multipart((sig, msg), copy=False) -async def recv_signed_async(socket: zmq.asyncio.Socket): + +async def recv_signed_async(socket: zmq.asyncio.Socket, key: bytes) -> bytes: """Get signed message from asyncio socket.""" sig, msg = await socket.recv_multipart(copy=False) - if not check_signed(sig, msg): + if not check_signed(key, sig, msg.buffer): raise ValueError("Message signature is invalid.") - return msg \ No newline at end of file + return msg.buffer diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 2e27224b41864..274bc0241a62c 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -5,6 +5,7 @@ import multiprocessing import os import re +import secrets import signal import socket import tempfile @@ -188,10 +189,13 @@ async def build_async_engine_client_from_engine_args( # not actually result in an exitcode being reported. As a result # we use a shared variable to communicate the information. engine_alive = multiprocessing.Value('b', True, lock=False) + secret_key = secrets.token_bytes(16) engine_process = context.Process(target=run_mp_engine, args=(engine_args, UsageContext.OPENAI_API_SERVER, - ipc_path, engine_alive)) + ipc_path, + secret_key, + engine_alive)) engine_process.start() engine_pid = engine_process.pid assert engine_pid is not None, "Engine process failed to start." @@ -208,7 +212,7 @@ def _cleanup_ipc_path(): # Build RPCClient, which conforms to EngineClient Protocol. engine_config = engine_args.create_engine_config() build_client = partial(MQLLMEngineClient, ipc_path, engine_config, - engine_pid) + engine_pid, secret_key) mq_engine_client = await asyncio.get_running_loop().run_in_executor( None, build_client) try: