diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index de88f1e9ac14f..6e86b803683b0 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -17,6 +17,7 @@ import vllm.envs as envs from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState +from vllm.attention.backends.utils import CommonAttentionState from vllm.compilation.compile_context import set_compile_context from vllm.compilation.levels import CompilationLevel from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, @@ -1028,6 +1029,17 @@ def __init__( self.graph_block_tables = np.zeros( (self.max_batchsize_to_capture, self.get_max_block_per_batch()), dtype=np.int32) + + # Attention-free but stateful models like Mamba need a placeholder attn + # backend, as the attention metadata is needed to manage internal state. + # However we must bypass attention selection altogether for some models + # used for speculative decoding to avoid a divide-by-zero in + # model_config.get_head_size() + num_attn_heads = self.model_config.get_num_attention_heads( + self.parallel_config) + needs_attn_backend = (num_attn_heads != 0 + or self.model_config.is_attention_free) + self.attn_backend = get_attn_backend( self.model_config.get_head_size(), self.model_config.get_sliding_window(), @@ -1035,9 +1047,12 @@ def __init__( self.kv_cache_dtype, self.block_size, self.model_config.is_attention_free, - ) - self.attn_state = self.attn_backend.get_state_cls()( - weakref.proxy(self)) + ) if needs_attn_backend else None + if self.attn_backend: + self.attn_state = self.attn_backend.get_state_cls()( + weakref.proxy(self)) + else: + self.attn_state = CommonAttentionState(weakref.proxy(self)) # Multi-modal data support self.input_registry = input_registry