From 981d3dbc2d786b27031426cc832cf61a2a50d3f6 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Sat, 2 Nov 2024 16:12:05 -0700
Subject: [PATCH 1/2] don't do zmq i/o on critical path

---
 vllm/v1/engine/core.py | 52 +++++++++++++++++++++++++++++-------------
 1 file changed, 36 insertions(+), 16 deletions(-)

diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index f5378b7a25a11..5e3df1c6e8a1d 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -1,5 +1,8 @@
 import multiprocessing
+import queue
+from collections.abc import Buffer
 from multiprocessing.process import BaseProcess
+from threading import Thread
 from typing import List, Tuple, Type
 
 import msgspec
@@ -155,6 +158,9 @@ def __init__(
 
         self.ctx = zmq.Context()  # type: ignore[attr-defined]
 
+        self.input_queue: queue.Queue[Buffer] = queue.Queue()
+        self.output_queue: queue.Queue[Buffer] = queue.Queue()
+
         # Get EngineCoreRequests from the LLMEngine.
         self.input_socket = self.ctx.socket(zmq.constants.PULL)
         self.input_socket.connect(input_path)
@@ -163,6 +169,9 @@ def __init__(
         self.output_socket = self.ctx.socket(zmq.constants.PUSH)
         self.output_socket.bind(output_path)
 
+        Thread(target=self.process_input_socket, daemon=True).start()
+        Thread(target=self.process_output_socket, daemon=True).start()
+
         # Send Readiness signal to LLMEngine.
         ready_socket = None
         try:
@@ -173,6 +182,18 @@ def __init__(
             if ready_socket:
                 ready_socket.close(linger=0)
 
+    def process_input_socket(self):
+        while True:
+            frames = self.input_socket.recv_multipart(copy=False)
+            self.input_queue.put_nowait(frames[0].buffer)
+
+    def process_output_socket(self):
+        while True:
+            serialized = self.output_queue.get()
+            self.output_socket.send_multipart((serialized, ),
+                                              copy=False,
+                                              flags=zmq.NOBLOCK)
+
     @staticmethod
     def wait_for_startup(
         proc: BaseProcess,
@@ -244,8 +265,8 @@ def run_busy_loop(self):
         while True:
             # Poll the input socket until there is work to do.
             if not self.scheduler.has_unfinished_requests():
-                while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
-                    logger.debug("Waiting for new requests.")
+                buffer = self.input_queue.get()
+                self._handle_input_buffer(buffer)
 
             # Handle new input from the socket.
             self._handle_new_input()
@@ -256,24 +277,25 @@ def run_busy_loop(self):
             # Send outputs to the EngineCoreClient.
             self._send_outputs(outputs)
 
-    def _handle_new_input(self):
-        """Handle new input from the AsyncLLMEngine for async mode."""
-
+    def _handle_input_buffer(self, buffer):
         try:
-            if self.input_socket.poll(timeout=0) != 0:
-                frames = self.input_socket.recv_multipart(copy=False)
-                engine_core_request = self.msgpack_decoder.decode(
-                    frames[0].buffer)
-                self.add_request(engine_core_request)
+            engine_core_request = self.msgpack_decoder.decode(buffer)
+            self.add_request(engine_core_request)
 
-                # TODO: handle abort via another socket
-                # TODO: handle logits processors via cloudpickle
-                # TODO: handle profiling
+            # TODO: handle abort via another socket
+            # TODO: handle logits processors via cloudpickle
+            # TODO: handle profiling
 
         except Exception as e:
             # TODO: handle gracefully
             raise e
 
+    def _handle_new_input(self):
+        """Handle new input from the AsyncLLMEngine for async mode."""
+        while not self.input_queue.empty():
+            buffer = self.input_queue.get_nowait()
+            self._handle_input_buffer(buffer)
+
     def _send_outputs(self,
                       engine_core_outputs: List[EngineCoreOutput]) -> None:
         """Serialize and send output to the AsyncLLMEngine for async mode."""
@@ -283,6 +305,4 @@ def _send_outputs(self,
 
         outputs = EngineCoreOutputs(outputs=engine_core_outputs)
         outputs_serialized = self.msgpack_encoder.encode(outputs)
-        self.output_socket.send_multipart((outputs_serialized, ),
-                                          copy=False,
-                                          flags=zmq.NOBLOCK)
+        self.output_queue.put_nowait(outputs_serialized)

From a904bad1d77e783687eb7484a8042d433fe4302a Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Sat, 2 Nov 2024 16:26:42 -0700
Subject: [PATCH 2/2] also move ser/deser into
 separate threads

---
 vllm/v1/engine/core.py | 43 ++++++++++++++++++++----------------------
 1 file changed, 20 insertions(+), 23 deletions(-)

diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index 5e3df1c6e8a1d..fa983fc39ce76 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -1,6 +1,5 @@
 import multiprocessing
 import queue
-from collections.abc import Buffer
 from multiprocessing.process import BaseProcess
 from threading import Thread
 from typing import List, Tuple, Type
@@ -158,8 +157,8 @@ def __init__(
 
         self.ctx = zmq.Context()  # type: ignore[attr-defined]
 
-        self.input_queue: queue.Queue[Buffer] = queue.Queue()
-        self.output_queue: queue.Queue[Buffer] = queue.Queue()
+        self.input_queue = queue.Queue()
+        self.output_queue = queue.Queue()
 
         # Get EngineCoreRequests from the LLMEngine.
         self.input_socket = self.ctx.socket(zmq.constants.PULL)
@@ -185,12 +184,15 @@ def __init__(
     def process_input_socket(self):
         while True:
             frames = self.input_socket.recv_multipart(copy=False)
-            self.input_queue.put_nowait(frames[0].buffer)
+            request = self.msgpack_decoder.decode(frames[0].buffer)
+            self.input_queue.put_nowait(request)
 
     def process_output_socket(self):
         while True:
-            serialized = self.output_queue.get()
-            self.output_socket.send_multipart((serialized, ),
+            engine_core_outputs = self.output_queue.get()
+            outputs = EngineCoreOutputs(outputs=engine_core_outputs)
+            outputs_serialized = self.msgpack_encoder.encode(outputs)
+            self.output_socket.send_multipart((outputs_serialized, ),
                                               copy=False,
                                               flags=zmq.NOBLOCK)
 
@@ -265,8 +267,8 @@ def run_busy_loop(self):
         while True:
             # Poll the input socket until there is work to do.
             if not self.scheduler.has_unfinished_requests():
-                buffer = self.input_queue.get()
-                self._handle_input_buffer(buffer)
+                request = self.input_queue.get()
+                self._handle_request(request)
 
             # Handle new input from the socket.
             self._handle_new_input()
@@ -277,10 +279,15 @@ def run_busy_loop(self):
             # Send outputs to the EngineCoreClient.
             self._send_outputs(outputs)
 
-    def _handle_input_buffer(self, buffer):
+    def _handle_new_input(self):
+        """Handle new input from the AsyncLLMEngine for async mode."""
+        while not self.input_queue.empty():
+            request = self.input_queue.get_nowait()
+            self._handle_request(request)
+
+    def _handle_request(self, request: EngineCoreRequest):
         try:
-            engine_core_request = self.msgpack_decoder.decode(buffer)
-            self.add_request(engine_core_request)
+            self.add_request(request)
 
             # TODO: handle abort via another socket
             # TODO: handle logits processors via cloudpickle
@@ -290,19 +297,9 @@ def _handle_input_buffer(self, buffer):
             # TODO: handle gracefully
             raise e
 
-    def _handle_new_input(self):
-        """Handle new input from the AsyncLLMEngine for async mode."""
-        while not self.input_queue.empty():
-            buffer = self.input_queue.get_nowait()
-            self._handle_input_buffer(buffer)
-
     def _send_outputs(self,
                       engine_core_outputs: List[EngineCoreOutput]) -> None:
         """Serialize and send output to the AsyncLLMEngine for async mode."""
 
-        if not engine_core_outputs:
-            return
-
-        outputs = EngineCoreOutputs(outputs=engine_core_outputs)
-        outputs_serialized = self.msgpack_encoder.encode(outputs)
-        self.output_queue.put_nowait(outputs_serialized)
+        if engine_core_outputs:
+            self.output_queue.put_nowait(engine_core_outputs)