Skip to content

Commit

Permalink
e2e working
Browse files Browse the repository at this point in the history
  • Loading branch information
robertgshaw2-neuralmagic committed Dec 12, 2024
1 parent 4934eeb commit 573d228
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 60 deletions.
2 changes: 1 addition & 1 deletion examples/openai_completion_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

# Modify OpenAI's API key and API base to use vLLM's API server.
openai_api_key = "EMPTY"
openai_api_base = "http://localhost:8001/v1"
openai_api_base = "http://localhost:8002/v1"

client = OpenAI(
# defaults to os.environ.get("OPENAI_API_KEY")
Expand Down
1 change: 0 additions & 1 deletion vllm/engine/multiprocessing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@

VLLM_RPC_SUCCESS_STR = "SUCCESS"


IPC_INPUT_EXT = "_input_socket"
IPC_OUTPUT_EXT = "_output_socket"
IPC_HEALTH_EXT = "_health_socket"
Expand Down
49 changes: 31 additions & 18 deletions vllm/engine/multiprocessing/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
from vllm.engine.multiprocessing.ipc import (send_signed_async,
from vllm.engine.multiprocessing.ipc import (send_signed_async,
recv_signed_async)

from vllm.engine.protocol import EngineClient
Expand Down Expand Up @@ -84,8 +84,9 @@ class MQLLMEngineClient(EngineClient):
"""

def __init__(self, ipc_path: str, engine_config: VllmConfig,
engine_pid: int):
engine_pid: int, secret_key: bytes):
self.context = zmq.asyncio.Context()
self.secret_key = secret_key
self._errored_with: Optional[BaseException] = None

# Get the configs.
Expand Down Expand Up @@ -163,6 +164,7 @@ async def run_heartbeat_loop(self, timeout: int):
# Heartbeat received - check the message
await self._check_success(
error_message="Heartbeat failed.",
secret_key=self.secret_key,
socket=self.heartbeat_socket)

logger.debug("Heartbeat successful.")
Expand Down Expand Up @@ -195,7 +197,8 @@ async def run_output_handler_loop(self):
ENGINE_DEAD_ERROR(self._errored_with))
return

message = await recv_signed_async(self.output_socket)
message = await recv_signed_async(self.output_socket,
self.secret_key)
request_outputs = pickle.loads(message)

is_error = isinstance(request_outputs,
Expand Down Expand Up @@ -290,20 +293,21 @@ def _set_errored(self, e: BaseException):
async def _send_get_data_rpc_request(request: RPCStartupRequest,
expected_type: Any,
error_message: str,
socket: Socket) -> Any:
socket: Socket,
secret_key: bytes) -> Any:
"""Send an RPC request that is expecting data back."""

# Ping RPCServer with a request.
await send_signed_async(socket, pickle.dumps(request))
await send_signed_async(socket, secret_key, pickle.dumps(request))

# Make sure the server responds in time.
if await socket.poll(timeout=VLLM_RPC_TIMEOUT) == 0:
raise TimeoutError("RPCServer didn't reply within "
f"{VLLM_RPC_TIMEOUT} ms")

# Await the data from the Server.
frame = await recv_signed_async(socket)
data = pickle.loads(frame.buffer)
message = await recv_signed_async(socket, secret_key)
data = pickle.loads(message)

if isinstance(data, BaseException):
raise data
Expand All @@ -314,13 +318,14 @@ async def _send_get_data_rpc_request(request: RPCStartupRequest,

@staticmethod
async def _send_one_way_rpc_request(request: RPC_REQUEST_T,
socket: Socket):
socket: Socket,
secret_key: bytes):
"""Send one-way RPC request to trigger an action."""

if socket.closed:
raise MQClientClosedError()

await send_signed_async(socket, pickle.dumps(request))
await send_signed_async(socket, secret_key, pickle.dumps(request))

async def _await_ack(self, error_message: str, socket: Socket):
"""Await acknowledgement that a request succeeded."""
Expand All @@ -332,17 +337,18 @@ async def _await_ack(self, error_message: str, socket: Socket):
raise TimeoutError("MQLLMEngine didn't reply within "
f"{VLLM_RPC_TIMEOUT}ms")

await self._check_success(error_message, socket)
await self._check_success(error_message, self.secret_key, socket)

@staticmethod
async def _check_success(error_message: str, socket: Socket):
async def _check_success(error_message: str, secret_key: bytes,
socket: Socket):
"""Confirm that socket has a VLLM_RPC_SUCCESS_STR message"""

if socket.closed:
raise MQClientClosedError()

frame = await recv_signed_async(socket, error_message)
response = pickle.loads(frame.buffer)
message = await recv_signed_async(socket, secret_key)
response = pickle.loads(message)

# Raise error if unsuccessful
if isinstance(response, BaseException):
Expand Down Expand Up @@ -373,6 +379,7 @@ async def _wait_for_server_rpc(self, socket: Socket) -> RPCStartupResponse:
request=RPCStartupRequest.IS_SERVER_READY,
expected_type=RPCStartupResponse,
error_message="Unable to start RPC Server",
secret_key=self.secret_key,
socket=socket)

async def abort(self, request_id: str):
Expand Down Expand Up @@ -627,9 +634,11 @@ async def _process_request(
))

# 3) Send the RPCGenerateRequest to the MQLLMEngine.
parts = (request_bytes,
lp_bytes) if lp_bytes else (request_bytes, )
await send_signed_async(self.input_socket, parts)

# parts = (request_bytes,
# lp_bytes) if lp_bytes else (request_bytes, )
await send_signed_async(self.input_socket, self.secret_key,
request_bytes)

# 4) Stream the RequestOutputs from the output queue. Note
# that the output_loop pushes RequestOutput objects to this
Expand All @@ -655,10 +664,14 @@ async def start_profile(self) -> None:
"""Start profiling the engine"""

await self._send_one_way_rpc_request(
request=RPCUProfileRequest.START_PROFILE, socket=self.input_socket)
request=RPCUProfileRequest.START_PROFILE,
secret_key=self.secret_key,
socket=self.input_socket)

async def stop_profile(self) -> None:
"""Stop profiling the engine"""

await self._send_one_way_rpc_request(
request=RPCUProfileRequest.STOP_PROFILE, socket=self.input_socket)
request=RPCUProfileRequest.STOP_PROFILE,
secret_key=self.secret_key,
socket=self.input_socket)
48 changes: 29 additions & 19 deletions vllm/engine/multiprocessing/engine.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import pickle
import signal
import hmac
from contextlib import contextmanager
from typing import Iterator, List, Optional, Union

Expand All @@ -18,7 +17,8 @@
RPCError, RPCProcessRequest,
RPCStartupRequest, RPCStartupResponse,
RPCUProfileRequest)
from vllm.engine.multiprocessing.ipc import send
from vllm.engine.multiprocessing.ipc import (send_signed, recv_signed,
check_signed, sign)

# yapf: enable
from vllm.executor.gpu_executor import GPUExecutor
Expand All @@ -29,7 +29,7 @@
logger = init_logger(__name__)

POLLING_TIMEOUT_MS = 10000
HEALTHY_RESPONSE = (pickle.dumps(VLLM_RPC_SUCCESS_STR), )
HEALTHY_RESPONSE = pickle.dumps(VLLM_RPC_SUCCESS_STR)


class MQLLMEngine:
Expand Down Expand Up @@ -63,6 +63,7 @@ class MQLLMEngine:
def __init__(self,
ipc_path: str,
use_async_sockets: bool,
secret_key: bytes,
*args,
log_requests: bool = True,
**kwargs) -> None:
Expand All @@ -80,6 +81,7 @@ def __init__(self,
self._async_socket_engine_callback

self.ctx = zmq.Context() # type: ignore[attr-defined]
self.secret_key = secret_key

# Receive input from the client.
self.input_socket = self.ctx.socket(zmq.constants.PULL)
Expand Down Expand Up @@ -108,8 +110,10 @@ def dead_error(self) -> BaseException:

@classmethod
def from_engine_args(cls, engine_args: AsyncEngineArgs,
usage_context: UsageContext, ipc_path: str):
usage_context: UsageContext, ipc_path: str,
secret_key: bytes):
"""Creates an MQLLMEngine from the engine arguments."""

# Setup plugins for each process
from vllm.plugins import load_general_plugins
load_general_plugins()
Expand All @@ -121,6 +125,7 @@ def from_engine_args(cls, engine_args: AsyncEngineArgs,

return cls(ipc_path=ipc_path,
use_async_sockets=use_async_sockets,
secret_key=secret_key,
vllm_config=engine_config,
executor_class=executor_class,
log_requests=not engine_args.disable_log_requests,
Expand Down Expand Up @@ -164,7 +169,9 @@ def run_startup_loop(self) -> None:
with self.make_data_socket() as socket:
response: Union[RPCStartupResponse, BaseException]
try:
identity, message = socket.recv_multipart(copy=False)
identity, sig, message = socket.recv_multipart(copy=False)
if not check_signed(self.secret_key, sig, message.buffer):
raise ValueError("Message Signature is invalid.")
request: RPCStartupRequest = pickle.loads(message.buffer)

# Handle the query from the Client.
Expand All @@ -176,8 +183,9 @@ def run_startup_loop(self) -> None:
except Exception as e:
response = e

socket.send_multipart((identity, pickle.dumps(response)),
copy=False)
response_bytes = pickle.dumps(response)
sig = sign(self.secret_key, response_bytes)
socket.send_multipart((identity, sig, response_bytes), copy=False)

def run_engine_loop(self):
"""Core busy loop of the LLMEngine."""
Expand Down Expand Up @@ -220,15 +228,16 @@ def handle_new_input(self):
"""Handle new input from the socket"""
try:
while self.input_socket.poll(timeout=0) != 0:
frames = self.input_socket.recv_multipart(copy=False)
request = pickle.loads(frames[0].buffer)
message = recv_signed(self.input_socket, self.secret_key)
request = pickle.loads(message)

if isinstance(request, RPCProcessRequest):
if len(frames) > 1:
# Use cloudpickle for logits processors
assert isinstance(request.params, SamplingParams)
lprocs = cloudpickle.loads(frames[1].buffer)
request.params.logits_processors = lprocs
# TODO: handle the cloudpickle case (logits processors sent as a second frame).
# if len(frames) > 1:
# # Use cloudpickle for logits processors
# assert isinstance(request.params, SamplingParams)
# lprocs = cloudpickle.loads(frames[1].buffer)
# request.params.logits_processors = lprocs
self._handle_process_request(request)
elif isinstance(request, RPCAbortRequest):
self._handle_abort_request(request)
Expand Down Expand Up @@ -313,18 +322,18 @@ def _send_outputs(self, outputs: REQUEST_OUTPUTS_T):
pass

output_bytes = pickle.dumps(outputs)
send(self.output_socket, output_bytes)
send_signed(self.output_socket, self.secret_key, output_bytes)

def _send_healthy(self):
"""Send HEALTHY message to RPCClient."""
if not self.heartbeat_socket.closed:
send(self.heartbeat_socket, HEALTHY_RESPONSE)
send_signed(self.heartbeat_socket, self.secret_key, HEALTHY_RESPONSE)

Check failure on line 330 in vllm/engine/multiprocessing/engine.py

View workflow job for this annotation

GitHub Actions / ruff (3.12)

Ruff (E501)

vllm/engine/multiprocessing/engine.py:330:81: E501 Line too long (81 > 80)

def _send_unhealthy(self, error: BaseException):
"""Send UNHEALTHY message to RPCClient."""
if not self.heartbeat_socket.closed:
error_bytes = pickle.dumps(error)
send(self.heartbeat_socket, error_bytes)
send_signed(self.heartbeat_socket, self.secret_key, error_bytes)

def _async_socket_engine_callback(self,
request_outputs: REQUEST_OUTPUTS_T):
Expand Down Expand Up @@ -355,11 +364,12 @@ def signal_handler(*_) -> None:


def run_mp_engine(engine_args: AsyncEngineArgs, usage_context: UsageContext,
ipc_path: str, engine_alive):
ipc_path: str, secret_key: bytes, engine_alive):
try:
engine = MQLLMEngine.from_engine_args(engine_args=engine_args,
usage_context=usage_context,
ipc_path=ipc_path)
ipc_path=ipc_path,
secret_key=secret_key)

signal.signal(signal.SIGTERM, signal_handler)

Expand Down
43 changes: 24 additions & 19 deletions vllm/engine/multiprocessing/ipc.py
Original file line number Diff line number Diff line change
@@ -1,50 +1,55 @@
import hashlib
import hmac
import secrets

import zmq
import zmq.asyncio

# TODO: switch to SECRET_KEY = secrets.token_bytes(16)
# and pass the SECRET_KEY to the background process.
SECRET_KEY = b"my_key"

def sign(msg: bytes) -> bytes:
"""Compute the HMAC digest of msg, given signing key `key`"""
def sign(key: bytes, msg: bytes) -> bytes:
"""Compute the HMAC-SHA256 digest of msg using the signing key."""

return hmac.HMAC(
SECRET_KEY,
key,
msg,
digestmod=hashlib.sha256,
).digest()

def check_signed(sig: bytes, msg: bytes) -> bool:
correct_sig = sign(msg)

def check_signed(key: bytes, sig: bytes, msg: bytes) -> bool:
"""Check if signature (HMAC digest) matches."""

correct_sig = sign(key, msg)
return hmac.compare_digest(sig, correct_sig)

def send_signed(socket: zmq.Socket, msg: bytes):

def send_signed(socket: zmq.Socket, key: bytes, msg: bytes):
"""Send signed message to socket."""

sig = sign(msg)
sig = sign(key, msg)
socket.send_multipart((sig, msg), copy=False)

def recv_signed(socket: zmq.Socket):

def recv_signed(socket: zmq.Socket, key: bytes) -> bytes:
"""Get signed message from socket."""

sig, msg = socket.recv_multipart(copy=False)
if not check_signed(sig, msg):
if not check_signed(key, sig, msg.buffer):
raise ValueError("Message signature is invalid.")
return msg
return msg.buffer


async def send_signed_async(socket: zmq.asyncio.Socket, msg: bytes):
async def send_signed_async(socket: zmq.asyncio.Socket, key: bytes,
msg: bytes):
"""Send signed message to asyncio socket."""

sig = sign(msg)
sig = sign(key, msg)
await socket.send_multipart((sig, msg), copy=False)

async def recv_signed_async(socket: zmq.asyncio.Socket):

async def recv_signed_async(socket: zmq.asyncio.Socket, key: bytes) -> bytes:
"""Get signed message from asyncio socket."""

sig, msg = await socket.recv_multipart(copy=False)
if not check_signed(sig, msg):
if not check_signed(key, sig, msg.buffer):
raise ValueError("Message signature is invalid.")
return msg
return msg.buffer
Loading

0 comments on commit 573d228

Please sign in to comment.