diff --git a/examples/infinitestore_pd_separate.sh b/examples/infinitestore_pd_separate.sh
new file mode 100755
index 0000000000000..334186bc7ef18
--- /dev/null
+++ b/examples/infinitestore_pd_separate.sh
@@ -0,0 +1,153 @@
+#!/bin/bash
+set -euo pipefail
+
+# =========================
+# Configuration Parameters
+# =========================
+
+PREFILL_PORT=8100
+DECODE_PORT=8200
+
+MODEL="facebook/opt-125m"
+
+PREFILL_LOG="/tmp/prefill.log"
+DECODE_LOG="/tmp/decode.log"
+
+START_TIMEOUT=120
+WAIT_INTERVAL=1
+
+PROMPT="San Francisco is a"
+
+PORTS=($PREFILL_PORT $DECODE_PORT)
+LOGS=("$PREFILL_LOG" "$DECODE_LOG")
+STAGES=("prefill" "decode")
+GPUS=(0 1)
+
+# =========================
+# Function Definitions
+# =========================
+
+# Function to check if a command exists
+command_exists() {
+    command -v "$1" &>/dev/null
+}
+
+# Function to log messages with timestamps
+log() {
+    local message="$1"
+    echo "$(date '+%Y-%m-%d %H:%M:%S') $message"
+}
+
+# Function to check if a port is in use
+check_port() {
+    local port=$1
+    if lsof -i :"$port" -t &>/dev/null; then
+        log "Error: Port $port is in use."
+        exit 1
+    fi
+}
+
+# Function to start a vllm server
+start_vllm_server() {
+    local gpu_id=$1
+    local stage=$2
+    local port=$3
+    local log_file=$4
+    CUDA_VISIBLE_DEVICES="$gpu_id" PD_SEPARATE_STAGE="$stage" \
+        vllm serve "$MODEL" --enforce-eager --port "$port" --dtype=float16 > "$log_file" 2>&1 &
+}
+
+# Function to wait for a vllm endpoint to become ready
+wait_for_endpoint() {
+    local port=$1
+    local elapsed=0
+    while true; do
+        if curl --output /dev/null --silent --fail "http://localhost:$port/v1/models"; then
+            log "vllm on port $port is ready!"
+            break
+        fi
+        if [ $elapsed -ge $START_TIMEOUT ]; then
+            log "Error: vllm on port $port is not ready after $START_TIMEOUT seconds."
+            log "Check log file for more details."
+            exit 1
+        fi
+        sleep $WAIT_INTERVAL
+        elapsed=$((elapsed + WAIT_INTERVAL))
+    done
+}
+
+# Function to clean up background processes on exit
+cleanup() {
+    log "Cleaning up background processes..."
+    pkill -f "vllm serve" || true
+}
+
+trap cleanup EXIT
+
+# =========================
+# Main Script Execution
+# =========================
+
+# Check for required commands
+for cmd in vllm curl lsof jq nvidia-smi; do
+    if ! command_exists "$cmd"; then
+        log "Error: Required command '$cmd' is not installed."
+        exit 1
+    fi
+done
+
+# Check that InfiniStore support is available
+EXIT_CODE=0
+OUTPUT=$(python3 -c "from infinistore import check_supported; \
+result = check_supported(); \
+print(result)" 2>&1) || EXIT_CODE=$?
+
+if [ $EXIT_CODE -ne 0 ]; then
+    log "Error: InfiniStore is not supported: $OUTPUT"
+    exit $EXIT_CODE
+fi
+
+# Check that there are at least 2 GPUs
+GPU_COUNT=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
+if [ "$GPU_COUNT" -lt 2 ]; then
+    log "Error: Fewer than 2 GPUs detected."
+    exit 1
+fi
+
+# Check that the ports are not in use
+for port in "${PORTS[@]}"; do
+    check_port "$port"
+done
+
+# Start vllm servers
+for i in "${!PORTS[@]}"; do
+    log "Starting vllm server (${STAGES[$i]}) on port ${PORTS[$i]}..."
+    start_vllm_server "${GPUS[$i]}" "${STAGES[$i]}" "${PORTS[$i]}" "${LOGS[$i]}"
+done
+
+# Wait for vllm endpoints to become ready
+for port in "${PORTS[@]}"; do
+    wait_for_endpoint "$port"
+done
+
+log "All vllm endpoints are ready!"
+
+# Prepare JSON data
+DATA=$(jq -n \
+    --arg model "$MODEL" \
+    --arg prompt "$PROMPT" \
+    '{model: $model, prompt: $prompt}')
+
+log "Sending request to prefill and decode..."
+
+# Send requests
+prefill_output=$(curl -s "http://localhost:${PREFILL_PORT}/v1/completions" \
+    -H "Content-Type: application/json" \
+    -d "$DATA")
+
+decode_output=$(curl -s "http://localhost:${DECODE_PORT}/v1/completions" \
+    -H "Content-Type: application/json" \
+    -d "$DATA")
+
+# Display outputs
+printf "Prefill output:\n%s\n\nDecode output:\n%s\n" "$prefill_output" "$decode_output"
diff --git a/examples/multi_host_infinitestore_pd_separate.sh b/examples/multi_host_infinitestore_pd_separate.sh
new file mode 100755
index 0000000000000..ff38008cb7c49
--- /dev/null
+++ b/examples/multi_host_infinitestore_pd_separate.sh
@@ -0,0 +1,224 @@
+#!/bin/bash
+set -euo pipefail
+
+# =========================
+# Configuration Parameters
+# =========================
+
+# Replace these with the actual IP addresses of your hosts
+PREFILL_HOST="10.192.18.145"
+DECODE_HOST="10.192.24.218"
+
+INFINITY_HOST="10.192.18.145"  # Host running the InfiniStore server
+
+PORT=8000
+
+MODEL="facebook/opt-125m"
+
+PREFILL_LOG="/tmp/prefill.log"
+DECODE_LOG="/tmp/decode.log"
+
+START_TIMEOUT=120
+WAIT_INTERVAL=1
+
+PROMPT="San Francisco is a"
+
+STAGES=("prefill" "decode")
+HOSTS=("$PREFILL_HOST" "$DECODE_HOST")
+GPUS=(0 0)
+LOGS=("$PREFILL_LOG" "$DECODE_LOG")
+
+# Conda environments to activate on each host (replace with your own)
+PREFILL_CONDA_ENV="qian2"
+DECODE_CONDA_ENV="qian"
+CONDA_ENVS=("$PREFILL_CONDA_ENV" "$DECODE_CONDA_ENV")
+
+
+# =========================
+# Function Definitions
+# =========================
+
+# Function to check if a host is the local machine
+is_local_host() {
+    local host_ip="$1"
+    local local_ips
+    local_ips=$(hostname -I)
+    if [[ "$host_ip" == "127.0.0.1" || "$host_ip" == "localhost" ]]; then
+        return 0
+    fi
+    for ip in $local_ips; do
+        if [[ "$host_ip" == "$ip" ]]; then
+            return 0
+        fi
+    done
+    return 1
+}
+
+# Function to check if a command exists on a host
+command_exists_on_host() {
+    local host="$1"
+    local conda_env="$2"
+    local cmd="$3"
+    if is_local_host "$host"; then
+        source ~/.bashrc
+        conda activate "$conda_env"
+        command -v "$cmd" &>/dev/null
+    else
+        ssh "$host" "bash -c 'source ~/.bashrc; conda activate $conda_env; command -v $cmd &>/dev/null'"
+    fi
+}
+
+# Function to log messages with timestamps
+log() {
+    local message="$1"
+    echo "$(date '+%Y-%m-%d %H:%M:%S') $message"
+}
+
+# Function to start a vllm server on a host
+start_vllm_server_on_host() {
+    local host="$1"
+    local conda_env="$2"
+    local gpu_id="$3"
+    local stage="$4"
+    local port="$5"
+    local log_file="$6"
+    if is_local_host "$host"; then
+        source ~/.bashrc
+        conda activate "$conda_env"
+        CUDA_VISIBLE_DEVICES="$gpu_id" PD_SEPARATE_STAGE="$stage" INFINITE_STORE_SERVER="$INFINITY_HOST" \
+            vllm serve "$MODEL" --enforce-eager --port "$port" --dtype=float16 > "$log_file" 2>&1 &
+    else
+        ssh "$host" "bash -c 'source ~/.bashrc; conda activate $conda_env; \
+            CUDA_VISIBLE_DEVICES=\"$gpu_id\" PD_SEPARATE_STAGE=\"$stage\" INFINITE_STORE_SERVER=\"$INFINITY_HOST\" \
+            vllm serve \"$MODEL\" --enforce-eager --port \"$port\" --dtype=float16 > \"$log_file\" 2>&1 &'"
+    fi
+}
+
+# Function to wait for a vllm endpoint to become ready on a host
+wait_for_endpoint() {
+    local host="$1"
+    local port="$2"
+    local elapsed=0
+    while true; do
+        if curl --output /dev/null --silent --fail "http://$host:$port/v1/models"; then
+            log "vllm on $host:$port is ready!"
+            break
+        fi
+        if [ $elapsed -ge $START_TIMEOUT ]; then
+            log "Error: vllm on $host:$port is not ready after $START_TIMEOUT seconds."
+            log "Check log file on the host for more details."
+            exit 1
+        fi
+        sleep $WAIT_INTERVAL
+        elapsed=$((elapsed + WAIT_INTERVAL))
+    done
+}
+
+# Function to clean up background processes on hosts
+cleanup() {
+    log "Cleaning up background processes..."
+    for i in "${!HOSTS[@]}"; do
+        host="${HOSTS[$i]}"
+        conda_env="${CONDA_ENVS[$i]}"
+        if is_local_host "$host"; then
+            pkill -f 'vllm serve' || true
+        else
+            ssh "$host" "pkill -f 'vllm serve' || true"
+        fi
+    done
+}
+
+trap cleanup EXIT
+
+# =========================
+# Main Script Execution
+# =========================
+
+# Check for required commands on hosts
+for i in "${!HOSTS[@]}"; do
+    host="${HOSTS[$i]}"
+    conda_env="${CONDA_ENVS[$i]}"
+    for cmd in vllm curl nvidia-smi; do
+        if ! command_exists_on_host "$host" "$conda_env" "$cmd"; then
+            log "Error: Required command '$cmd' is not installed on host $host in conda environment '$conda_env'."
+            exit 1
+        fi
+    done
+done
+
+# Check that InfiniStore is supported on each host
+for i in "${!HOSTS[@]}"; do
+    host="${HOSTS[$i]}"
+    conda_env="${CONDA_ENVS[$i]}"
+    EXIT_CODE=0
+    if is_local_host "$host"; then
+        source ~/.bashrc
+        conda activate "$conda_env"
+        OUTPUT=$(python3 -c 'from infinistore import check_supported; result = check_supported(); print(result)' 2>&1) || EXIT_CODE=$?
+    else
+        OUTPUT=$(ssh "$host" "bash -c 'source ~/.bashrc; conda activate $conda_env; python3 -c \"from infinistore import check_supported; result = check_supported(); print(result)\"' 2>&1") || EXIT_CODE=$?
+    fi
+
+    if [ $EXIT_CODE -ne 0 ]; then
+        log "Error: InfiniStore is not supported on host $host: $OUTPUT"
+        exit $EXIT_CODE
+    fi
+done
+
+# Check that there is at least 1 GPU on each host
+for i in "${!HOSTS[@]}"; do
+    host="${HOSTS[$i]}"
+    conda_env="${CONDA_ENVS[$i]}"
+    if is_local_host "$host"; then
+        source ~/.bashrc
+        conda activate "$conda_env"
+        GPU_COUNT=$(nvidia-smi --query-gpu=name --format=csv,noheader | wc -l)
+    else
+        GPU_COUNT=$(ssh "$host" "bash -c 'source ~/.bashrc; conda activate $conda_env; nvidia-smi --query-gpu=name --format=csv,noheader | wc -l'")
+    fi
+    if [ "$GPU_COUNT" -lt 1 ]; then
+        log "Error: No GPUs detected on host $host."
+        exit 1
+    fi
+done
+
+# Start vllm servers on hosts
+for i in "${!HOSTS[@]}"; do
+    host="${HOSTS[$i]}"
+    conda_env="${CONDA_ENVS[$i]}"
+    log "Starting vllm server (${STAGES[$i]}) on ${HOSTS[$i]}:${PORT}..."
+    start_vllm_server_on_host "$host" "$conda_env" "${GPUS[$i]}" "${STAGES[$i]}" "$PORT" "${LOGS[$i]}"
+done
+
+# Wait for vllm endpoints to become ready on hosts
+for i in "${!HOSTS[@]}"; do
+    wait_for_endpoint "${HOSTS[$i]}" "$PORT"
+done
+
+log "All vllm endpoints are ready!"
+
+# Prepare JSON data
+DATA=$(jq -n \
+    --arg model "$MODEL" \
+    --arg prompt "$PROMPT" \
+    '{model: $model, prompt: $prompt}')
+
+log "Sending request to prefill and decode..."
+ +# Send requests to hosts +prefill_output=$(curl -s "http://${PREFILL_HOST}:${PORT}/v1/completions" \ + -H "Content-Type: application/json" \ + -d "$DATA") + +decode_output=$(curl -s "http://${DECODE_HOST}:${PORT}/v1/completions" \ + -H "Content-Type: application/json" \ + -d "$DATA") + +# Display outputs +printf "Prefill output:\n%s\n\nDecode output:\n%s\n" "$prefill_output" "$decode_output" \ No newline at end of file diff --git a/vllm/distributed/kv_transfer/__init__.py b/vllm/distributed/kv_transfer/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/vllm/distributed/kv_transfer/base.py b/vllm/distributed/kv_transfer/base.py new file mode 100644 index 0000000000000..42efdbf02597d --- /dev/null +++ b/vllm/distributed/kv_transfer/base.py @@ -0,0 +1,83 @@ +from abc import ABC, abstractmethod +import torch + +from vllm.attention import AttentionMetadata + + +class KVCacheTransporterBase(ABC): + + @abstractmethod + def save_kv_cache( + self, + input_ids: torch.Tensor, + attn_metadata: AttentionMetadata, + layer_idx: int, + kv_cache: torch.Tensor, + ): + """ + Save the key-value cache for a specific layer. + + Args: + input_ids (torch.Tensor): The input token IDs. + attn_metadata (AttentionMetadata): Metadata related to attention. + layer_idx (int): The index of the layer. + kv_cache (torch.Tensor): The key-value cache tensor. + """ + raise NotImplementedError + + @abstractmethod + def read_kv_cache( + self, + input_ids: torch.Tensor, + attn_metadata: AttentionMetadata, + layer_idx: int, + kv_cache: torch.Tensor, + ): + """ + Read the key-value cache. + + Args: + input_ids (torch.Tensor): The input token IDs. + attn_metadata (AttentionMetadata): Metadata related to attention. + kv_cache (torch.Tensor): The key-value cache tensor to be populated. + """ + raise NotImplementedError + + @abstractmethod + def save_hidden_states( + self, + input_ids: torch.Tensor, + attn_metadata: AttentionMetadata, + hidden_states: torch.Tensor, + ): + """ + Save the hidden states. + + Args: + input_ids (torch.Tensor): The input token IDs. + attn_metadata (AttentionMetadata): Metadata related to attention. + hidden_states (torch.Tensor): The hidden states tensor. + """ + raise NotImplementedError + + @abstractmethod + def read_hidden_states( + self, + input_ids: torch.Tensor, + attn_metadata: AttentionMetadata, + hidden_states: torch.Tensor, + ): + """ + read the hidden states. + + Args: + input_ids (torch.Tensor): The input token IDs. + attn_metadata (AttentionMetadata): Metadata related to attention. + hidden_states (torch.Tensor): The hidden states tensor. 
+ """ + raise NotImplementedError + + @abstractmethod + def synchronize(self): + """Synchronize any asynchronous operations.""" + raise NotImplementedError diff --git a/vllm/distributed/kv_transfer/infinite.py b/vllm/distributed/kv_transfer/infinite.py new file mode 100644 index 0000000000000..6a9c1f749b7e6 --- /dev/null +++ b/vllm/distributed/kv_transfer/infinite.py @@ -0,0 +1,237 @@ +import math +import hashlib +import logging +from typing import Dict, List, Tuple +import torch +import os + +import infinistore + +from vllm.attention import AttentionMetadata +from vllm.distributed.kv_transfer.base import KVCacheTransporterBase + +logger = logging.getLogger(__name__) + +Default_Infinite_Server = "127.0.0.1" + +class InfiniStoreKVCacheTransporter(KVCacheTransporterBase): + + def __init__(self, model: str, tokens_per_page=16) -> None: + if not model: + raise ValueError("model cannot be empty.") + if tokens_per_page <= 0: + raise ValueError("tokens_per_page must be greater than 0.") + + self.model = model + self.tokens_per_page = tokens_per_page + + infinite_server = os.environ.get("INFINITE_STORE_SERVER", Default_Infinite_Server) + infinite_server = infinite_server.strip('"') + infinte_config = infinistore.ClientConfig( + host_addr=infinite_server, + service_port=22345, + log_level="warning", + connection_type=infinistore.TYPE_RDMA, + ) + + self.conn = infinistore.InfinityConnection(infinte_config) + + logger.info("connecting to infinite store server: ", infinite_server) + + self.conn.connect() + + def _compute_kv_cache_block_offsets( + self, input_ids: torch.Tensor, attn_metadata: AttentionMetadata, + seq_index: int, seq_length: int, layer_idx: int, + kv_cache: torch.Tensor) -> Tuple[List[Tuple[str, int]], int]: + + seq_tokens = input_ids[seq_index:seq_index + seq_length].cpu().numpy() + num_pages = math.ceil(seq_length / self.tokens_per_page) + block_offsets: List[Tuple[str, int]] = [] + prev_hash = "" + page_size = kv_cache[0][0].numel() # Number of elements in one page + k_or_v_cache_size = kv_cache[0].numel( + ) # Size of key or value cache per token + + for page_num in range(num_pages): + # Calculate token indices for the current page + start_token = page_num * self.tokens_per_page + end_token = min((page_num + 1) * self.tokens_per_page, seq_length) + tokens_in_page = seq_tokens[start_token:end_token] + + # Compute the hash for the current page + tokens_bytes = tokens_in_page.tobytes() + hash_input = prev_hash.encode('utf-8') + tokens_bytes + current_hash = hashlib.sha256(hash_input).hexdigest() + + # Generate cache keys using the current hash + k_cache_key = f"{self.model}_{current_hash}_layer_{layer_idx}_k" + v_cache_key = f"{self.model}_{current_hash}_layer_{layer_idx}_v" + + # Calculate the offset in the kv_cache for the current page + try: + slot_index = page_num * self.tokens_per_page + slot_mapping_value = attn_metadata.slot_mapping[ + seq_index + slot_index].item() + page_offset = (slot_mapping_value // + self.tokens_per_page) * page_size + except IndexError as e: + logger.error("Invalid slot mapping index %s: %s", slot_index, + e) + raise + + block_offsets.append((k_cache_key, page_offset)) + block_offsets.append( + (v_cache_key, page_offset + k_or_v_cache_size)) + + # Update the previous hash for the next page + prev_hash = current_hash + + logger.debug( + "Computed kv_cache block offsets: layer %s, page %s, " + "k_cache_key %s, v_cache_key %s", layer_idx, page_num, + k_cache_key, v_cache_key) + + return block_offsets, page_size + + def _compute_hidden_states_block_offsets( + 
self, input_ids: torch.Tensor, attn_metadata: AttentionMetadata, + seq_index: int, seq_length: int, + hidden_states: torch.Tensor) -> Dict[int, List[Tuple[str, int]]]: + + seq_tokens = input_ids[seq_index:seq_index + seq_length].cpu().numpy() + num_pages = math.ceil(seq_length / self.tokens_per_page) + block_offsets: Dict[int, List[Tuple[str, int]]] = {} + prev_hash = "" + hidden_size = hidden_states.size(-1) + + for page_num in range(num_pages): + # Calculate token indices for the current page + start_token = page_num * self.tokens_per_page + end_token = min((page_num + 1) * self.tokens_per_page, seq_length) + tokens_in_page = seq_tokens[start_token:end_token] + + # Compute the hash for the current page + tokens_bytes = tokens_in_page.tobytes() + hash_input = prev_hash.encode('utf-8') + tokens_bytes + current_hash = hashlib.sha256(hash_input).hexdigest() + + # Generate cache key using the current hash + cache_key = f"{self.model}_{current_hash}_hidden_states" + + # Calculate cache size and offset + cache_size = hidden_size * (end_token - start_token) + offset = (seq_index + start_token) * hidden_size + + if cache_size not in block_offsets: + block_offsets[cache_size] = [] + block_offsets[cache_size].append((cache_key, offset)) + + # Update the previous hash for the next page + prev_hash = current_hash + + logger.debug( + "Computed hidden_states block offsets: page %s, cache_key %s", + page_num, cache_key) + + return block_offsets + + def save_kv_cache(self, input_ids: torch.Tensor, + attn_metadata: AttentionMetadata, layer_idx: int, + kv_cache: torch.Tensor) -> None: + + seq_index = 0 + + for seq_length_tensor in attn_metadata.seq_lens_tensor: + seq_length = seq_length_tensor.item() + block_offsets, page_size = self._compute_kv_cache_block_offsets( + input_ids, attn_metadata, seq_index, seq_length, layer_idx, + kv_cache) + + # Write to cache + try: + self.conn.write_cache(kv_cache, block_offsets, page_size) + except Exception as e: + logger.error("Failed to write kv_cache: %s", e) + raise + + seq_index += seq_length + + logger.debug("Saved kv_cache for layer %s", layer_idx) + + def read_kv_cache(self, input_ids: torch.Tensor, + attn_metadata: AttentionMetadata, layer_idx: int, + kv_cache: torch.Tensor) -> None: + + seq_index = 0 + + for seq_length_tensor in attn_metadata.seq_lens_tensor: + seq_length = seq_length_tensor.item() + block_offsets, page_size = self._compute_kv_cache_block_offsets( + input_ids, attn_metadata, seq_index, seq_length, layer_idx, + kv_cache) + + # Read from cache + try: + self.conn.read_cache(kv_cache, block_offsets, page_size) + except Exception as e: + logger.error("Failed to read kv_cache: %s", e) + raise + + seq_index += seq_length + + logger.debug("Loaded kv_cache for layer %s", layer_idx) + + def save_hidden_states(self, input_ids: torch.Tensor, + attn_metadata: AttentionMetadata, + hidden_states: torch.Tensor) -> None: + + seq_index = 0 + + for seq_length_tensor in attn_metadata.seq_lens_tensor: + seq_length = seq_length_tensor.item() + block_offsets = self._compute_hidden_states_block_offsets( + input_ids, attn_metadata, seq_index, seq_length, hidden_states) + + # Write to cache + try: + for cache_size, offsets in block_offsets.items(): + self.conn.write_cache(hidden_states, offsets, cache_size) + except Exception as e: + logger.error("Failed to write hidden_states: %s", e) + raise + + seq_index += seq_length + + logger.debug("Saved hidden_states") + + def read_hidden_states(self, input_ids: torch.Tensor, + attn_metadata: AttentionMetadata, + hidden_states: 
torch.Tensor) -> None: + + seq_index = 0 + + for seq_length_tensor in attn_metadata.seq_lens_tensor: + seq_length = seq_length_tensor.item() + block_offsets = self._compute_hidden_states_block_offsets( + input_ids, attn_metadata, seq_index, seq_length, hidden_states) + + # Read from cache + try: + for cache_size, offsets in block_offsets.items(): + self.conn.read_cache(hidden_states, offsets, cache_size) + except Exception as e: + logger.error("Failed to read hidden_states: %s", e) + raise + + seq_index += seq_length + + logger.debug("Loaded hidden_states") + + def synchronize(self) -> None: + try: + self.conn.sync() + logger.debug("Synchronized with Infinity service") + except Exception as e: + logger.error("Failed to synchronize: %s", e) + raise diff --git a/vllm/distributed/kv_transfer/utils.py b/vllm/distributed/kv_transfer/utils.py new file mode 100644 index 0000000000000..e5410db4e18be --- /dev/null +++ b/vllm/distributed/kv_transfer/utils.py @@ -0,0 +1,32 @@ +import os +import torch + +from vllm.attention import AttentionMetadata + + +def _get_pd_sep_stage(): + return os.environ.get("PD_SEPARATE_STAGE", "").lower() + + +def _is_profile_run(input_ids: torch.Tensor): + # profile_run will send in an all-zero input_ids tensor + return torch.any(input_ids == 0).item() + + +def is_first_decode_pass(input_ids: torch.tensor, + attn_metadata: AttentionMetadata): + if _get_pd_sep_stage() != "decode": + return False + + if _is_profile_run(input_ids): + return False + + return (attn_metadata.prefill_metadata is not None + and attn_metadata.decode_metadata is None) + + +def is_prefill_run(input_ids: torch.Tensor): + if _get_pd_sep_stage() != "prefill": + return False + + return not _is_profile_run(input_ids) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 3f17e9004c30f..616aba16e3006 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -32,6 +32,8 @@ from vllm.config import CacheConfig, LoRAConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) +from vllm.distributed.kv_transfer.utils import (is_first_decode_pass, + is_prefill_run) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, @@ -322,7 +324,18 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: + + first_decode_pass = is_first_decode_pass(input_ids, attn_metadata) + prefill_pass = is_prefill_run(input_ids) + + if first_decode_pass or prefill_pass: + if 'kv_cache_transporter' not in kwargs: + raise ValueError( + "Missing 'kv_cache_transporter' in keyword arguments.") + kv_cache_transporter = kwargs['kv_cache_transporter'] + if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -334,12 +347,29 @@ def forward( hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] + if first_decode_pass: + for i, kv_cache in enumerate(kv_caches): + kv_cache_transporter.read_kv_cache(input_ids, attn_metadata, i, + kv_cache) + + kv_cache_transporter.read_hidden_states(input_ids, attn_metadata, + hidden_states) + + kv_cache_transporter.synchronize() + + return hidden_states + for i in range(self.start_layer, self.end_layer): layer = 
self.layers[i] hidden_states, residual = layer(positions, hidden_states, kv_caches[i - self.start_layer], attn_metadata, residual) + if prefill_pass: + kv_cache_transporter.save_kv_cache( + input_ids, attn_metadata, i, + kv_caches[i - self.start_layer]) + if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -347,6 +377,12 @@ def forward( }) hidden_states, _ = self.norm(hidden_states, residual) + + if prefill_pass: + kv_cache_transporter.save_hidden_states(input_ids, attn_metadata, + hidden_states) + kv_cache_transporter.synchronize() + return hidden_states def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): @@ -546,9 +582,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: model_output = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + **kwargs) return model_output def compute_logits( diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 3bcdb0d87fd52..cfb941db6b1ab 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -25,6 +25,8 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig +from vllm.distributed.kv_transfer.utils import (is_first_decode_pass, + is_prefill_run) from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, @@ -252,7 +254,17 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: + first_decode_pass = is_first_decode_pass(input_ids, attn_metadata) + prefill_pass = is_prefill_run(input_ids) + + if first_decode_pass or prefill_pass: + if 'kv_cache_transporter' not in kwargs: + raise ValueError( + "Missing 'kv_cache_transporter' in keyword arguments.") + kv_cache_transporter = kwargs['kv_cache_transporter'] + if get_pp_group().is_first_rank: if inputs_embeds is None: inputs_embeds = self.get_input_embeddings(input_ids) @@ -264,11 +276,28 @@ def forward( assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] + if first_decode_pass: + for i, kv_cache in enumerate(kv_caches): + kv_cache_transporter.read_kv_cache(input_ids, attn_metadata, i, + kv_cache) + + kv_cache_transporter.read_hidden_states(input_ids, attn_metadata, + hidden_states) + + kv_cache_transporter.synchronize() + + return hidden_states + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states = layer(hidden_states, kv_caches[i - self.start_layer], attn_metadata) + + if prefill_pass: + kv_cache_transporter.save_kv_cache( + input_ids, attn_metadata, i, + kv_caches[i - self.start_layer]) if not get_pp_group().is_last_rank: return IntermediateTensors({"hidden_states": hidden_states}) @@ -276,6 +305,12 @@ def forward( hidden_states = self.final_layer_norm(hidden_states) if self.project_out is not None: hidden_states, _ = self.project_out(hidden_states) + + if prefill_pass: + kv_cache_transporter.save_hidden_states(input_ids, attn_metadata, + hidden_states) + kv_cache_transporter.synchronize() + return hidden_states @@ -304,13 +339,15 @@ def forward( 
attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, + **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: return self.decoder(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, - inputs_embeds=inputs_embeds) + inputs_embeds=inputs_embeds, + **kwargs) class OPTForCausalLM(nn.Module, SupportsPP): @@ -355,9 +392,11 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, + **kwargs, ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata, intermediate_tensors) + attn_metadata, intermediate_tensors, + **kwargs) return hidden_states def compute_logits( diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index 95345df43b57d..b526f7513d955 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -1,5 +1,6 @@ """Sampling parameters for text generation.""" import copy +import os from dataclasses import dataclass from enum import Enum, IntEnum from functools import cached_property @@ -389,6 +390,13 @@ def _verify_args(self) -> None: RequestOutputKind.DELTA): raise ValueError("best_of must equal n to use output_kind=DELTA") + if os.environ.get("PD_SEPARATE_STAGE", "").lower() == "prefill": + if self.max_tokens is None or self.max_tokens != 1: + logger.warning("Prefill run only generates one token. " + "max_tokens is set to 1.") + + self.max_tokens = 1 + def _verify_greedy_sampling(self) -> None: if self.n > 1: raise ValueError("n must be 1 when using greedy sampling, " diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5bc7100732291..42e45b5f6243b 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -25,6 +25,10 @@ PromptAdapterConfig, SchedulerConfig) from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import get_pp_group +from vllm.distributed.kv_transfer.base import KVCacheTransporterBase +from vllm.distributed.kv_transfer.infinite import InfiniStoreKVCacheTransporter +from vllm.distributed.kv_transfer.utils import (is_first_decode_pass, + is_prefill_run) from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry @@ -1601,6 +1605,15 @@ def prepare_model_input( is_prompt=is_prompt, virtual_engine=virtual_engine) + def get_kv_cache_transporter( + self, input_ids, + attn_metadata) -> Optional[KVCacheTransporterBase]: + if is_prefill_run(input_ids) or is_first_decode_pass( + input_ids, attn_metadata): + return InfiniStoreKVCacheTransporter(self.model_config.model) + + return None + @torch.inference_mode() @dump_input_when_exception(exclude_args=[0], exclude_kwargs=["self"]) def execute_model( @@ -1662,8 +1675,10 @@ def execute_model( attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + device=self.device), + **seqlen_agnostic_kwargs, + kv_cache_transporter=self.get_kv_cache_transporter( + model_input.input_tokens, model_input.attn_metadata)) if (self.observability_config is not None and self.observability_config.collect_model_forward_time):
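
Note (illustration only, not part of the patch): the link between the prefill and decode workers is the cache-key scheme in InfiniStoreKVCacheTransporter._compute_kv_cache_block_offsets, which chains a SHA-256 hash across 16-token pages of the prompt. The standalone Python sketch below reproduces that key derivation so it can be inspected without a GPU or an InfiniStore deployment. The token IDs and the int64 dtype are assumptions made for the example, and the helper name page_cache_keys is invented here; the real transporter hashes the raw bytes of the input_ids slice it receives.

# Standalone sketch of the chained page-hash key derivation (illustration).
import hashlib
import math
from typing import List, Tuple

import numpy as np

TOKENS_PER_PAGE = 16  # matches the transporter's default tokens_per_page


def page_cache_keys(model: str, token_ids: List[int],
                    layer_idx: int) -> List[Tuple[str, str]]:
    """Return one (k_cache_key, v_cache_key) pair per page of the prompt."""
    seq_tokens = np.asarray(token_ids, dtype=np.int64)  # dtype is an assumption
    num_pages = math.ceil(len(seq_tokens) / TOKENS_PER_PAGE)
    keys: List[Tuple[str, str]] = []
    prev_hash = ""
    for page_num in range(num_pages):
        page = seq_tokens[page_num * TOKENS_PER_PAGE:(page_num + 1) *
                          TOKENS_PER_PAGE]
        # Chain the previous page's hash into this page's hash, so each key
        # identifies the whole token prefix up to and including this page.
        current_hash = hashlib.sha256(prev_hash.encode("utf-8") +
                                      page.tobytes()).hexdigest()
        keys.append((f"{model}_{current_hash}_layer_{layer_idx}_k",
                     f"{model}_{current_hash}_layer_{layer_idx}_v"))
        prev_hash = current_hash
    return keys


if __name__ == "__main__":
    prompt = [2, 13, 1038, 6406, 16, 10]  # made-up token IDs
    for k_key, v_key in page_cache_keys("facebook/opt-125m", prompt,
                                        layer_idx=0):
        print(k_key)
        print(v_key)

Because each page hash folds in the hash of the previous page, a key can only be reused when the entire token prefix matches, which is what lets the decode server read back exactly the KV blocks and hidden states the prefill server wrote for the same prompt.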