Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Overlap io #25

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 35 additions & 18 deletions vllm/v1/engine/core.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import multiprocessing
import queue
from multiprocessing.process import BaseProcess
from threading import Thread
from typing import List, Tuple, Type

import msgspec
Expand Down Expand Up @@ -155,6 +157,9 @@ def __init__(

self.ctx = zmq.Context() # type: ignore[attr-defined]

self.input_queue = queue.Queue()
self.output_queue = queue.Queue()

# Get EngineCoreRequests from the LLMEngine.
self.input_socket = self.ctx.socket(zmq.constants.PULL)
self.input_socket.connect(input_path)
Expand All @@ -163,6 +168,9 @@ def __init__(
self.output_socket = self.ctx.socket(zmq.constants.PUSH)
self.output_socket.bind(output_path)

Thread(target=self.process_input_socket, daemon=True).start()
Thread(target=self.process_output_socket, daemon=True).start()

# Send Readiness signal to LLMEngine.
ready_socket = None
try:
Expand All @@ -173,6 +181,21 @@ def __init__(
if ready_socket:
ready_socket.close(linger=0)

def process_input_socket(self):
    """Input IO loop: move decoded requests from the socket to the queue.

    Runs forever. Each iteration receives one multipart message from
    ``self.input_socket``, decodes the buffer of its first frame with
    ``self.msgpack_decoder`` and enqueues the result on
    ``self.input_queue`` without blocking.
    """
    while True:
        # Only the first frame of each message carries the request.
        message_parts = self.input_socket.recv_multipart(copy=False)
        payload = message_parts[0].buffer
        decoded = self.msgpack_decoder.decode(payload)
        # Non-blocking put, equivalent to put_nowait().
        self.input_queue.put(decoded, block=False)

def process_output_socket(self):
    """Output IO loop: serialize queued engine outputs and send them.

    Runs forever. Each iteration blocks on ``self.output_queue`` for the
    next batch of outputs, wraps it in ``EngineCoreOutputs``, encodes it
    with ``self.msgpack_encoder`` and sends the result as a single-frame
    message on ``self.output_socket``.
    """
    while True:
        engine_core_outputs = self.output_queue.get()
        outputs = EngineCoreOutputs(outputs=engine_core_outputs)
        outputs_serialized = self.msgpack_encoder.encode(outputs)
        # Deliberately a blocking send. With zmq.NOBLOCK a momentarily
        # unavailable peer raises zmq.Again, which would terminate this
        # loop and silently drop every later output. Blocking here only
        # stalls this loop, not the code that fills output_queue.
        self.output_socket.send_multipart((outputs_serialized, ),
                                          copy=False)

@staticmethod
def wait_for_startup(
proc: BaseProcess,
Expand Down Expand Up @@ -244,8 +267,8 @@ def run_busy_loop(self):
while True:
# Poll the input socket until there is work to do.
if not self.scheduler.has_unfinished_requests():
while self.input_socket.poll(timeout=POLLING_TIMEOUT_MS) == 0:
logger.debug("Waiting for new requests.")
request = self.input_queue.get()
self._handle_request(request)

# Handle new input from the socket.
self._handle_new_input()
Expand All @@ -258,17 +281,17 @@ def run_busy_loop(self):

def _handle_new_input(self):
    """Handle new input from the AsyncLLMEngine for async mode.

    Drains ``self.input_queue`` without blocking, passing each request to
    ``self._handle_request`` in the order it is dequeued, and returns as
    soon as the queue has nothing left.
    """
    while True:
        # Try the get directly instead of checking empty() first: a
        # non-empty answer from empty() does not guarantee that the
        # following get_nowait() succeeds.
        try:
            request = self.input_queue.get_nowait()
        except queue.Empty:
            return
        self._handle_request(request)

def _handle_request(self, request: EngineCoreRequest):
try:
if self.input_socket.poll(timeout=0) != 0:
frames = self.input_socket.recv_multipart(copy=False)
engine_core_request = self.msgpack_decoder.decode(
frames[0].buffer)
self.add_request(engine_core_request)
self.add_request(request)

# TODO: handle abort via another socket
# TODO: handle logits processors via cloudpickle
# TODO: handle profiling
# TODO: handle abort via another socket
# TODO: handle logits processors via cloudpickle
# TODO: handle profiling

except Exception as e:
# TODO: handle gracefully
Expand All @@ -278,11 +301,5 @@ def _send_outputs(self,
engine_core_outputs: List[EngineCoreOutput]) -> None:
"""Serialize and send output to the AsyncLLMEngine for async mode."""

if not engine_core_outputs:
return

outputs = EngineCoreOutputs(outputs=engine_core_outputs)
outputs_serialized = self.msgpack_encoder.encode(outputs)
self.output_socket.send_multipart((outputs_serialized, ),
copy=False,
flags=zmq.NOBLOCK)
if engine_core_outputs:
self.output_queue.put_nowait(engine_core_outputs)