Fix long seq bug #5

Closed
wants to merge 11 commits into from

Changes from all commits

7 changes: 6 additions & 1 deletion vllm/attention/backends/abstract.py
@@ -103,6 +103,10 @@ class AttentionMetadata:
# in block 0, and 1st slot in block 1, respectively.
slot_mapping: torch.Tensor

# Number of original input tokens (i.e. prompt tokens, before any decoding).
# Some models (e.g. the Phi-3 family) need this info to decide model settings.
num_orig_input_tokens_tensor: torch.Tensor

@property
@abstractmethod
def prefill_metadata(self) -> Optional["AttentionMetadata"]:
@@ -184,7 +188,8 @@ def __init__(self, input_builder: "ModelRunnerInputBuilderBase") -> None:

@abstractmethod
def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int) -> T:
num_orig_input_tokens_list: List[int], cuda_graph_pad_size: int,
batch_size: int) -> T:
"""Build attention metadata with on-device tensors."""
raise NotImplementedError

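For context on how the new build() argument is meant to be fed, the sketch below shows one way a model runner could collect num_orig_input_tokens_list before calling the builder. It only illustrates the intended data flow; the seq_group_metadata / seq_data accessors are assumptions for the example, not code from this PR.

# Hypothetical helper illustrating where num_orig_input_tokens_list comes from.
# The seq_group_metadata / seq_data accessors are assumptions for this sketch.
from typing import List

def collect_num_orig_input_tokens(seq_group_metadata_list) -> List[int]:
    """One entry per sequence: the original prompt length, excluding decoded tokens."""
    num_orig_input_tokens_list: List[int] = []
    for seq_group_metadata in seq_group_metadata_list:
        for seq_data in seq_group_metadata.seq_data.values():
            # Count prompt tokens only, so the value stays fixed for the whole
            # generation instead of growing as tokens are decoded.
            num_orig_input_tokens_list.append(len(seq_data.prompt_token_ids))
    return num_orig_input_tokens_list

# A builder's build() then materializes it next to seq_lens_tensor, e.g.:
#   num_orig_input_tokens_tensor = torch.tensor(
#       num_orig_input_tokens_list, dtype=torch.long, device=device)
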
6 changes: 6 additions & 0 deletions vllm/attention/backends/blocksparse_attn.py
@@ -220,6 +220,9 @@ def prefill_metadata(
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
num_orig_input_tokens_tensor=(
None if self.num_orig_input_tokens_tensor is None else
self.num_orig_input_tokens_tensor[:self.num_prefills]),
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
)
@@ -248,6 +251,9 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
num_orig_input_tokens_tensor=(
None if self.num_orig_input_tokens_tensor is None else
self.num_orig_input_tokens_tensor[self.num_prefills:]),
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)
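As an aside on the slicing convention in these cached views: every per-sequence tensor is split at num_prefills, with the prefill view taking the leading entries and the decode view the trailing ones, mirroring block_tables just above. A toy illustration of the split, with invented values:

# Invented values, purely to illustrate the [:num_prefills] / [num_prefills:] split.
import torch

num_prefills = 2  # the first two sequences in the batch are prefills
num_orig_input_tokens_tensor = torch.tensor([4096, 8000, 5, 7])

prefill_view = num_orig_input_tokens_tensor[:num_prefills]  # tensor([4096, 8000])
decode_view = num_orig_input_tokens_tensor[num_prefills:]   # tensor([5, 7])
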
11 changes: 10 additions & 1 deletion vllm/attention/backends/flash_attn.py
@@ -263,6 +263,8 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
num_orig_input_tokens_tensor=self.
num_orig_input_tokens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
@@ -291,6 +293,8 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
num_orig_input_tokens_tensor=self.
num_orig_input_tokens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
@@ -427,7 +431,8 @@ def _add_seq_group(
self.block_size, inter_data.block_tables)

def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
num_orig_input_tokens_list: List[int], cuda_graph_pad_size: int,
batch_size: int):
"""Build attention metadata with on-device tensors.

Args:
@@ -499,6 +504,9 @@ def build(self, seq_lens: List[int], query_lens: List[int],
dim=0,
dtype=query_start_loc.dtype,
out=query_start_loc[1:])
num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list,
dtype=torch.long,
device=device)

return FlashAttentionMetadata(
num_prefills=self.num_prefills,
@@ -507,6 +515,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
20 changes: 18 additions & 2 deletions vllm/attention/backends/flashinfer.py
@@ -134,6 +136,8 @@ def graph_capture(self, max_batch_size: int):
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)
self._num_orig_input_tokens_tensor = torch.zeros(
max_batch_size, dtype=torch.int32, device=self.runner.device)
self._graph_decode_workspace_buffer = self._get_workspace_buffer()
self._graph_indices_buffer = torch.empty(
max_batch_size * self.runner.cache_config.num_gpu_blocks,
@@ -154,6 +156,7 @@ def graph_capture(self, max_batch_size: int):
del self._graph_indptr_buffer
del self._graph_last_page_len_buffer
del self._graph_decode_wrapper
del self._num_orig_input_tokens_tensor

def graph_clone(self, batch_size: int):
assert self._is_graph_capturing
@@ -200,6 +203,8 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
slot_mapping=self._graph_slot_mapping[:batch_size],
num_prefill_tokens=0,
num_decode_tokens=batch_size,
num_orig_input_tokens_tensor=self.
_num_orig_input_tokens_tensor[:batch_size],
max_prefill_seq_len=0,
block_tables=self._graph_block_tables,
paged_kv_indptr=paged_kv_indptr_tensor_host,
@@ -221,10 +226,15 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):

def get_graph_input_buffers(self, attn_metadata):
return {
"slot_mapping": attn_metadata.slot_mapping,
"slot_mapping":
attn_metadata.slot_mapping,
"num_orig_input_tokens_tensor":
attn_metadata.num_orig_input_tokens_tensor,
}

def prepare_graph_input_buffers(self, input_buffers, attn_metadata):
input_buffers["num_orig_input_tokens_tensor"].copy_(
attn_metadata.num_orig_input_tokens_tensor, non_blocking=True)
return

def begin_forward(self, model_input):
@@ -494,7 +504,8 @@ def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int):
self.paged_kv_last_page_len.append(last_page_len)

def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
num_orig_input_tokens_list: List[int], cuda_graph_pad_size: int,
batch_size: int):
"""Build attention metadata with on-device tensors.

Args:
@@ -564,6 +575,10 @@ def build(self, seq_lens: List[int], query_lens: List[int],
dtype=query_start_loc.dtype,
out=query_start_loc[1:])

num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list,
dtype=torch.long,
device=device)

if len(self.paged_kv_indptr) > 0:
paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
device="cpu",
@@ -585,6 +600,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
max_prefill_seq_len=max_prefill_seq_len,
block_tables=block_tables,
paged_kv_indptr=paged_kv_indptr_tensor,
4 changes: 4 additions & 0 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -136,6 +138,8 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
num_orig_input_tokens_tensor=self.
num_orig_input_tokens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
@@ -164,6 +166,8 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
num_orig_input_tokens_tensor=self.
num_orig_input_tokens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
28 changes: 24 additions & 4 deletions vllm/attention/backends/utils.py
@@ -190,7 +190,8 @@ def _add_seq_group(
self.block_size, inter_data.block_tables)

def build(self, seq_lens: List[int], query_lens: List[int],
cuda_graph_pad_size: int, batch_size: int):
num_orig_input_tokens_list: List[int], cuda_graph_pad_size: int,
batch_size: int):
"""Build attention metadata with on-device tensors.

Args:
@@ -258,13 +259,18 @@ def build(self, seq_lens: List[int], query_lens: List[int],
dtype=query_start_loc.dtype,
out=query_start_loc[1:])

num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list,
dtype=torch.long,
device=device)

return self._metadata_cls( # type: ignore
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
max_query_len=max_query_len,
max_prefill_seq_len=max_prefill_seq_len,
max_decode_seq_len=max_decode_seq_len,
@@ -294,11 +300,16 @@ def graph_capture(self, max_batch_size: int):
device=self.runner.device)
self._graph_block_tables = torch.from_numpy(
self.runner.graph_block_tables).to(device=self.runner.device)

self._num_orig_input_tokens_tensor = torch.zeros(
max_batch_size, dtype=torch.int32, device=self.runner.device)

yield
self._is_graph_capturing = False
del self._graph_slot_mapping
del self._graph_seq_lens
del self._graph_block_tables
del self._num_orig_input_tokens_tensor

def graph_clone(self, batch_size: int) -> "CommonAttentionState":
assert self._is_graph_capturing
@@ -313,6 +324,8 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
slot_mapping=self._graph_slot_mapping[:batch_size],
seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size],
num_orig_input_tokens_tensor=self.
_num_orig_input_tokens_tensor[:batch_size],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.runner.max_seq_len_to_capture,
@@ -326,9 +339,14 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):

def get_graph_input_buffers(self, attn_metadata) -> Dict[str, Any]:
return {
"slot_mapping": attn_metadata.slot_mapping,
"seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables": attn_metadata.decode_metadata.block_tables,
"slot_mapping":
attn_metadata.slot_mapping,
"seq_lens_tensor":
attn_metadata.decode_metadata.seq_lens_tensor,
"block_tables":
attn_metadata.decode_metadata.block_tables,
"num_orig_input_tokens_tensor":
attn_metadata.num_orig_input_tokens_tensor,
}

def prepare_graph_input_buffers(self, input_buffers,
@@ -337,6 +355,8 @@ def prepare_graph_input_buffers(self, input_buffers,
attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
input_buffers["block_tables"].copy_(
attn_metadata.decode_metadata.block_tables, non_blocking=True)
input_buffers["num_orig_input_tokens_tensor"].copy_(
attn_metadata.num_orig_input_tokens_tensor, non_blocking=True)

def begin_forward(self, model_input) -> None:
return
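
The CUDA-graph handling above follows the usual capture/replay pattern: the metadata captured for the graph points at a slice of a persistent, pre-allocated buffer, and prepare_graph_input_buffers copies the live per-step values into that buffer before replay. A stripped-down sketch of that pattern in isolation (buffer and function names are illustrative, not taken from vLLM):

# Stripped-down sketch of the CUDA-graph input-buffer pattern; names are
# illustrative. The tensor read by the captured graph must be a fixed
# allocation, and fresh values are copied into it rather than swapped in
# by reference.
import torch

max_batch_size = 8
device = "cuda" if torch.cuda.is_available() else "cpu"

# Allocated once before capture; the captured graph reads from this storage,
# so its data_ptr must never change afterwards.
graph_num_orig_input_tokens = torch.zeros(max_batch_size,
                                          dtype=torch.long,
                                          device=device)

def prepare_replay(live_values: torch.Tensor) -> None:
    # Refresh the captured buffer in place before replaying the graph.
    batch_size = live_values.shape[0]
    graph_num_orig_input_tokens[:batch_size].copy_(live_values,
                                                   non_blocking=True)
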
8 changes: 8 additions & 0 deletions vllm/attention/backends/xformers.py
@@ -200,6 +200,9 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]:
self.seq_lens_tensor[:self.num_prefills])
context_lens_tensor = (None if self.context_lens_tensor is None else
self.context_lens_tensor[:self.num_prefills])
num_orig_input_tokens_tensor = (
None if self.num_orig_input_tokens_tensor is None else
self.num_orig_input_tokens_tensor[:self.num_prefills])
block_tables = (None if self.block_tables is None else
self.block_tables[:self.num_prefills])

@@ -211,6 +214,7 @@ def prefill_metadata(self) -> Optional["XFormersMetadata"]:
slot_mapping=slot_mapping,
seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor,
num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
@@ -245,6 +249,9 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]:
self.seq_lens_tensor[self.num_prefills:])
block_tables = (None if self.block_tables is None else
self.block_tables[self.num_prefills:])
num_orig_input_tokens_tensor = (
None if self.num_orig_input_tokens_tensor is None else
self.num_orig_input_tokens_tensor[self.num_prefills:])

# Construct & cache decode-phase attention metadata structure
self._cached_decode_metadata = XFormersMetadata(
@@ -253,6 +260,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]:
num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping,
seq_lens_tensor=seq_lens_tensor,
num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
block_tables=block_tables,
Expand Down
11 changes: 9 additions & 2 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -536,8 +536,8 @@ def __init__(
1 + math.log(scale) /
math.log(self.original_max_position_embeddings))

short_cache = self._compute_cos_sin_cache(
original_max_position_embeddings, short_factor, short_mscale)
short_cache = self._compute_cos_sin_cache(max_position_embeddings,
short_factor, short_mscale)
short_cache = short_cache.to(dtype)
self.register_buffer("short_cos_sin_cache",
short_cache,
@@ -582,11 +582,18 @@ def forward(
query: torch.Tensor,
key: torch.Tensor,
offsets: Optional[torch.Tensor] = None,
*,
num_orig_input_tokens_tensor: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
query = query.view(*query.shape[:-1], -1, self.head_size)
key = key.view(*key.shape[:-1], -1, self.head_size)

k = self.original_max_position_embeddings
long_prompt_offset = torch.where(
num_orig_input_tokens_tensor <= k,
torch.zeros_like(num_orig_input_tokens_tensor),
torch.full_like(num_orig_input_tokens_tensor,
self.max_position_embeddings))
long_prompt_offset = (torch.any(positions > k).float() *
torch.full_like(positions, k)).long()
idx = (torch.add(positions, long_prompt_offset)
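The forward() change is the substance of the fix. The pre-existing offset computation shown just above keys on torch.any(positions > k) across the whole flattened batch, so a single long prompt in the batch (or a decode position crossing k) flips every token onto the long-factor cache; the new tensor instead lets the offset be chosen per sequence from its original prompt length. A minimal sketch of that per-sequence selection, under assumed shapes and with illustrative names (not this PR's API):

# Minimal sketch of the per-sequence cache selection, in isolation. Assumed
# shapes: `positions` and `orig_num_tokens` are flattened per-token tensors of
# equal length, and `cos_sin_cache` is the short-factor cache stacked on top of
# the long-factor cache, each half `max_pos` rows long.
import torch

def select_cos_sin(positions: torch.Tensor,
                   orig_num_tokens: torch.Tensor,
                   cos_sin_cache: torch.Tensor,
                   original_max_pos: int,
                   max_pos: int) -> torch.Tensor:
    # A token uses the long-factor half iff the *original prompt* of its own
    # sequence exceeds the original context window, independent of what any
    # other sequence in the batch is doing.
    long_prompt_offset = torch.where(
        orig_num_tokens <= original_max_pos,
        torch.zeros_like(orig_num_tokens),
        torch.full_like(orig_num_tokens, max_pos))
    idx = positions + long_prompt_offset
    return torch.index_select(cos_sin_cache, 0, idx)

Keyed on prompt length, a sequence stays on the same factor set for its entire generation, which is what the num_orig_input_tokens_tensor plumbing through the attention backends exists to provide.
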
2 changes: 1 addition & 1 deletion vllm/model_executor/models/__init__.py
@@ -48,7 +48,7 @@
"OrionForCausalLM": ("orion", "OrionForCausalLM"),
"PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"),
"PhiForCausalLM": ("phi", "PhiForCausalLM"),
"Phi3ForCausalLM": ("llama", "LlamaForCausalLM"),
"Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"),
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"),