Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix Mamba multistep #11071

Merged
merged 1 commit into from
Dec 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 63 additions & 1 deletion vllm/attention/backends/placeholder_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
from vllm.multimodal import MultiModalPlaceholderMap

if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUBuilder
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)

# Placeholder attention backend for models like Mamba and embedding models that
# lack attention.
Expand Down Expand Up @@ -186,6 +187,67 @@ def decode_metadata(self) -> Optional["PlaceholderAttentionMetadata"]:
)
return self._cached_decode_metadata

def advance_step(self,
model_input: "ModelInputForGPUWithSamplingMetadata",
sampled_token_ids: Optional[torch.Tensor],
block_size: int,
num_seqs: int,
num_queries: int,
turn_prefills_into_decodes: bool = False):
"""
Update metadata in-place to advance one decode step.
"""
# When using cudagraph, the num_seqs is padded to the next captured
# batch sized, but num_queries tracks the actual number of requests in
# the batch. For --enforce-eager mode, num_seqs == num_queries
if num_seqs != num_queries:
assert num_seqs > num_queries
assert self.use_cuda_graph

assert not turn_prefills_into_decodes, \
("Multi-Step + Chunked-Prefill is not supported for attention-free"
"models. turn_prefills_into_decodes is a "
"Multi-Step + Chunked-Prefill specific parameter.")

assert self.seq_lens is not None
assert self.max_decode_seq_len == max(self.seq_lens)

assert self.num_prefills == 0
assert self.num_prefill_tokens == 0
assert self.num_decode_tokens == num_seqs

assert self.seq_lens is not None
assert len(self.seq_lens) == num_seqs
assert self.seq_lens_tensor is not None
assert self.seq_lens_tensor.shape == (num_seqs, )
assert self.max_query_len == 1
assert self.max_prefill_seq_len == 0

assert self.query_start_loc is not None
assert self.query_start_loc.shape == (num_queries + 1, )
assert self.seq_start_loc is not None
assert self.seq_start_loc.shape == (num_seqs + 1, )

assert self.context_lens_tensor is not None
assert self.context_lens_tensor.shape == (num_queries, )

assert self.block_tables is not None

# Update query lengths. Note that we update only queries and not seqs,
# since tensors may be padded due to captured cuda graph batch size
for i in range(num_queries):
self.seq_lens[i] += 1
self.max_decode_seq_len = max(self.seq_lens)

# Update sequences, masking off entries greater than num_queries
device = self.seq_lens_tensor.device
mask = torch.arange(self.seq_lens_tensor.size(0),
device=device) < num_queries
self.seq_lens_tensor += mask.to(self.seq_lens_tensor.dtype)
if sampled_token_ids is not None:
model_input.input_tokens.masked_scatter_(
mask, sampled_token_ids[:num_queries])


class PlaceholderAttentionMetadataBuilder(
AttentionMetadataBuilder[PlaceholderAttentionMetadata]):
Expand Down
4 changes: 3 additions & 1 deletion vllm/worker/multi_step_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,9 @@

logger = init_logger(__name__)

MULTI_STEP_ATTENTION_BACKENDS = ["FLASH_ATTN", "ROCM_FLASH", "FLASHINFER"]
MULTI_STEP_ATTENTION_BACKENDS = [
"FLASH_ATTN", "ROCM_FLASH", "FLASHINFER", "NO_ATTENTION"
]
MULTI_STEP_CHUNKED_PREFILL_ATTENTION_BACKENDS = ["FLASH_ATTN"]

def _get_supported_attention_backends(chunked_prefill_enabled: bool) \
Expand Down
Loading