[Bugfix] Bandaid fix for speculative decoding tests (vllm-project#9327)
Signed-off-by: Amit Garg <[email protected]>
tlrmchlsmth authored and garg-amit committed Oct 28, 2024
1 parent: d6d3698 · commit: 5e1e462
Showing 1 changed file with 18 additions and 3 deletions.
vllm/worker/model_runner.py (18 additions & 3 deletions)
@@ -17,6 +17,7 @@
 import vllm.envs as envs
 from vllm.attention import AttentionMetadata, get_attn_backend
 from vllm.attention.backends.abstract import AttentionState
+from vllm.attention.backends.utils import CommonAttentionState
 from vllm.compilation.compile_context import set_compile_context
 from vllm.compilation.levels import CompilationLevel
 from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
@@ -1028,16 +1029,30 @@ def __init__(
         self.graph_block_tables = np.zeros(
             (self.max_batchsize_to_capture, self.get_max_block_per_batch()),
             dtype=np.int32)
+
+        # Attention-free but stateful models like Mamba need a placeholder attn
+        # backend, as the attention metadata is needed to manage internal state.
+        # However we must bypass attention selection altogether for some models
+        # used for speculative decoding to avoid a divide-by-zero in
+        # model_config.get_head_size()
+        num_attn_heads = self.model_config.get_num_attention_heads(
+            self.parallel_config)
+        needs_attn_backend = (num_attn_heads != 0
+                              or self.model_config.is_attention_free)
+
         self.attn_backend = get_attn_backend(
             self.model_config.get_head_size(),
             self.model_config.get_sliding_window(),
             self.model_config.dtype,
             self.kv_cache_dtype,
             self.block_size,
             self.model_config.is_attention_free,
-        )
-        self.attn_state = self.attn_backend.get_state_cls()(
-            weakref.proxy(self))
+        ) if needs_attn_backend else None
+        if self.attn_backend:
+            self.attn_state = self.attn_backend.get_state_cls()(
+                weakref.proxy(self))
+        else:
+            self.attn_state = CommonAttentionState(weakref.proxy(self))

         # Multi-modal data support
         self.input_registry = input_registry
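For context, a minimal sketch of the divide-by-zero that the in-code comment refers to. ToyModelConfig below is a hypothetical stand-in for vllm.config.ModelConfig, assuming head size is derived as hidden_size // num_attention_heads; it is not vLLM's actual API.

# Sketch of the failure mode this commit works around. The real logic
# lives in ModelConfig.get_head_size(); the names here are illustrative.
class ToyModelConfig:
    def __init__(self, hidden_size: int, num_attention_heads: int) -> None:
        self.hidden_size = hidden_size
        self.num_attention_heads = num_attention_heads

    def get_head_size(self) -> int:
        # Head size is commonly derived by dividing the hidden size by the
        # head count, so a config reporting zero heads divides by zero.
        return self.hidden_size // self.num_attention_heads

# A speculative-decoding proposer with no attention layers reports zero
# attention heads, so calling get_head_size() raises ZeroDivisionError:
draft = ToyModelConfig(hidden_size=4096, num_attention_heads=0)
try:
    draft.get_head_size()
except ZeroDivisionError:
    print("attention backend selection must be bypassed for this model")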

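The new needs_attn_backend guard distinguishes three cases. A standalone restatement of the diff's condition (just the boolean logic, not vLLM API):

def needs_attn_backend(num_attn_heads: int, is_attention_free: bool) -> bool:
    # Mirrors the condition added in this diff.
    return num_attn_heads != 0 or is_attention_free

assert needs_attn_backend(32, False)     # ordinary transformer: real backend
assert needs_attn_backend(0, True)       # Mamba-style attention-free model:
                                         # placeholder backend for its state
assert not needs_attn_backend(0, False)  # spec-decode edge case: skip
                                         # get_attn_backend() entirely

When the guard is false, attn_backend is None, but the worker still gets a usable attn_state via CommonAttentionState, so code paths that only need attention metadata still have an object to work with.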