Skip to content

Commit

Permalink
Merge pull request #25 from njhill/overlap_io
Browse files Browse the repository at this point in the history
Overlap io
  • Loading branch information
robertgshaw2-neuralmagic authored Nov 4, 2024
2 parents bb1a75b + a904bad commit e3014e2
Showing 1 changed file with 35 additions and 18 deletions.
53 changes: 35 additions & 18 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import multiprocessing
import queue
from multiprocessing.process import BaseProcess
from threading import Thread
from typing import List, Tuple, Type

import msgspec
Expand Down Expand Up @@ -155,6 +157,9 @@ def __init__(

self.ctx = zmq.Context() # type: ignore[attr-defined]

self.input_queue = queue.Queue()
self.output_queue = queue.Queue()

# Get EngineCoreRequests from the LLMEngine.
self.input_socket = self.ctx.socket(zmq.constants.PULL)
self.input_socket.connect(input_path)
Expand All @@ -163,6 +168,9 @@ def __init__(
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
self.output_socket.bind(output_path)

Thread(target=self.process_input_socket, daemon=True).start()
Thread(target=self.process_output_socket, daemon=True).start()

# Send Readiness signal to LLMEngine.
ready_socket = None
try:
Expand All @@ -173,6 +181,21 @@ def __init__(
if ready_socket:
ready_socket.close(linger=0)

def process_input_socket(self):
while True:
frames = self.input_socket.recv_multipart(copy=False)
request = self.msgpack_decoder.decode(frames[0].buffer)
self.input_queue.put_nowait(request)

def process_output_socket(self):
while True:
engine_core_outputs = self.output_queue.get()
outputs = EngineCoreOutputs(outputs=engine_core_outputs)
outputs_serialized = self.msgpack_encoder.encode(outputs)
self.output_socket.send_multipart((outputs_serialized, ),
copy=False,
flags=zmq.NOBLOCK)

@staticmethod
def wait_for_startup(
proc: BaseProcess,
Expand Down Expand Up @@ -244,8 +267,8 @@ def run_busy_loop(self):
while True:
# Poll the input socket until there is work to do.
if not self.scheduler.has_unfinished_requests():
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
logger.debug("Waiting for new requests.")
request = self.input_queue.get()
self._handle_request(request)

# Handle new input from the socket.
self._handle_new_input()
Expand All @@ -258,17 +281,17 @@ def run_busy_loop(self):

def _handle_new_input(self):
"""Handle new input from the AsyncLLMEngine for async mode."""
while not self.input_queue.empty():
request = self.input_queue.get_nowait()
self._handle_request(request)

def _handle_request(self, request: EngineCoreRequest):
try:
if self.input_socket.poll(timeout=0) != 0:
frames = self.input_socket.recv_multipart(copy=False)
engine_core_request = self.msgpack_decoder.decode(
frames[0].buffer)
self.add_request(engine_core_request)
self.add_request(request)

# TODO: handle abort via another socket
# TODO: handle logits processors via cloudpickle
# TODO: handle profiling
# TODO: handle abort via another socket
# TODO: handle logits processors via cloudpickle
# TODO: handle profiling

except Exception as e:
# TODO: handle gracefully
Expand All @@ -278,11 +301,5 @@ def _send_outputs(self,
engine_core_outputs: List[EngineCoreOutput]) -> None:
"""Serialize and send output to the AsyncLLMEngine for async mode."""

if not engine_core_outputs:
return

outputs = EngineCoreOutputs(outputs=engine_core_outputs)
outputs_serialized = self.msgpack_encoder.encode(outputs)
self.output_socket.send_multipart((outputs_serialized, ),
copy=False,
flags=zmq.NOBLOCK)
if engine_core_outputs:
self.output_queue.put_nowait(engine_core_outputs)

0 comments on commit e3014e2

Please sign in to comment.