diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 43853063cfb40..e022f7481ee51 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -60,10 +60,14 @@
 
 LORA_WARMUP_RANK = 8
 _BATCH_SIZE_ALIGNMENT = 8
-# Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
+# all the token sizes that **can** be captured by cudagraph.
+# they can be arbitrarily large.
+# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192.
+# the actual sizes to capture will be determined by the model,
+# depending on the model's max_num_seqs.
 # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
 _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
-    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
+    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
 ]
 _NUM_WARMUP_ITERS = 2
 
@@ -660,7 +664,7 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata):
     def _use_captured_graph(self, batch_size: int,
                             max_decode_seq_len: int) -> bool:
         return (self.decode_only and not self.runner.model_config.enforce_eager
-                and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1]
+                and batch_size <= self.runner.max_batchsize_to_capture
                 and max_decode_seq_len <= self.runner.max_seq_len_to_capture)
 
     def build(self) -> ModelInputForGPU:
@@ -846,6 +850,8 @@ def __init__(
         self.sliding_window = model_config.get_sliding_window()
         self.block_size = cache_config.block_size
         self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
+        self.max_batchsize_to_capture = _get_max_graph_batch_size(
+            self.scheduler_config.max_num_seqs)
 
         self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [
             {} for _ in range(self.parallel_config.pipeline_parallel_size)
@@ -863,7 +869,7 @@ def __init__(
         # The shape of the cached block table will be
         # (max batch size to capture, max context len to capture / block size).
         self.graph_block_tables = np.zeros(
-            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
+            (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
             dtype=np.int32)
         num_attn_heads = self.model_config.get_num_attention_heads(
             self.parallel_config)
@@ -1218,7 +1224,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
         start_time = time.perf_counter()
 
         # Prepare dummy inputs. These will be reused for all batch sizes.
-        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
+        max_batch_size = self.max_batchsize_to_capture
         input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
         input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
 
@@ -1246,8 +1252,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None:
             None
         ] * self.parallel_config.pipeline_parallel_size
 
-        graph_batch_size = _get_graph_batch_size(
-            self.scheduler_config.max_num_seqs)
+        graph_batch_size = self.max_batchsize_to_capture
         batch_size_capture_list = [
             bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
         ]
@@ -1673,3 +1678,22 @@ def _get_graph_batch_size(batch_size: int) -> int:
     else:
         return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
                 _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)
+
+
+def _get_max_graph_batch_size(max_num_seqs: int) -> int:
+    """
+    max_num_seqs: Maximum number of sequences in a batch.
+    _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture.
+
+    pad the max_num_seqs if necessary by calling _get_graph_batch_size,
+    which will deal with some edge cases like 1, 2, 4.
+
+    if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size.
+    if not, it means the padded size is larger than the largest size in
+    _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE.
+    """
+    padded_size = _get_graph_batch_size(max_num_seqs)
+    if padded_size in _BATCH_SIZES_TO_CAPTURE:
+        return padded_size
+    assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
+    return _BATCH_SIZES_TO_CAPTURE[-1]
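For reference, below is a minimal standalone sketch of the capture-size selection this patch introduces, so the behavior can be checked without running vLLM. The `else` branch of `_get_graph_batch_size` is taken from the hunk context above; its small-batch branches are filled in here as an assumption for completeness, matching the "edge cases like 1, 2, 4" mentioned in the new docstring. The sample `max_num_seqs` values at the bottom are illustrative only, not part of the patch.

```python
# Standalone sketch of the batch-size selection logic in this patch.

_BATCH_SIZE_ALIGNMENT = 8
# 1, 2, 4, then every multiple of 8 up to 8 * 1024 = 8192.
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025)
]


def _get_graph_batch_size(batch_size: int) -> int:
    """Pad a batch size up to the next entry in _BATCH_SIZES_TO_CAPTURE."""
    if batch_size <= 2:
        return batch_size
    elif batch_size <= 4:
        return 4
    return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) //
            _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT)


def _get_max_graph_batch_size(max_num_seqs: int) -> int:
    """Largest batch size that will actually be captured for this config."""
    padded_size = _get_graph_batch_size(max_num_seqs)
    if padded_size in _BATCH_SIZES_TO_CAPTURE:
        return padded_size
    # Padded size exceeds the largest capturable size; clamp to it.
    assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1]
    return _BATCH_SIZES_TO_CAPTURE[-1]


# Illustrative values: the capture list is now bounded by the scheduler's
# max_num_seqs instead of the previous hard-coded maximum of 256.
for max_num_seqs in (3, 256, 300, 1000, 20000):
    cap = _get_max_graph_batch_size(max_num_seqs)
    sizes = [bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= cap]
    print(max_num_seqs, "->", cap, f"({len(sizes)} graphs captured)")
# 3 -> 4 (3 graphs captured)
# 256 -> 256 (35 graphs captured)
# 300 -> 304 (41 graphs captured)
# 1000 -> 1000 (128 graphs captured)
# 20000 -> 8192 (1027 graphs captured)
```

Note how a `max_num_seqs` above 8192 is clamped to the last entry of `_BATCH_SIZES_TO_CAPTURE`, which is exactly the assert-and-return path in `_get_max_graph_batch_size`.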