diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index b0f4d0530b7f0..66c865c272196 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -136,6 +136,7 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
             slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
+            num_orig_input_tokens_tensor=self.num_orig_input_tokens_tensor[:self.num_prefills],
             max_query_len=self.max_query_len,
             max_prefill_seq_len=self.max_prefill_seq_len,
             max_decode_seq_len=0,
@@ -164,6 +165,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
             slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
+            num_orig_input_tokens_tensor=self.num_orig_input_tokens_tensor[self.num_prefills:],
             max_query_len=None,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.max_decode_seq_len,
diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py
index 01daa64b5a32f..0438f79b4e0ea 100644
--- a/vllm/worker/tpu_model_runner.py
+++ b/vllm/worker/tpu_model_runner.py
@@ -165,6 +165,9 @@ def _dummy_run(
             position_ids = torch.zeros((batch_size, seq_len),
                                        dtype=torch.int32,
                                        device=self.device)
+            num_orig_input_tokens_tensor = torch.full((batch_size, seq_len), seq_len,
+                                                      dtype=torch.int32,
+                                                      device=self.device)
             slot_mapping = torch.zeros((batch_size, seq_len),
                                        dtype=torch.int64,
                                        device=self.device)
@@ -173,6 +176,7 @@
                 num_prefill_tokens=batch_size * seq_len,
                 num_decode_tokens=0,
                 slot_mapping=slot_mapping,
+                num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
                 block_tables=None,
                 context_lens=None,
             )
@@ -187,6 +191,9 @@ def _dummy_run(
             position_ids = torch.zeros((batch_size, seq_len),
                                        dtype=torch.int32,
                                        device=self.device)
+            num_orig_input_tokens_tensor = torch.ones((batch_size, seq_len),
+                                                      dtype=torch.int32,
+                                                      device=self.device)
             slot_mapping = torch.zeros((batch_size, seq_len),
                                        dtype=torch.int64,
                                        device=self.device)
@@ -205,6 +212,7 @@
                 num_prefill_tokens=0,
                 num_decode_tokens=batch_size * seq_len,
                 slot_mapping=slot_mapping,
+                num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
                 block_tables=block_tables,
                 context_lens=context_lens,
             )
@@ -285,6 +293,8 @@ def _prepare_prompt(
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[int] = []
         input_positions: List[int] = []
+        # The number of original input tokens of each sequence
+        num_orig_input_tokens_list: List[int] = []
         prompt_lens: List[int] = []
         slot_mapping: List[int] = []
 
@@ -302,6 +312,7 @@
 
             input_tokens.extend(prompt_tokens)
             input_positions.extend(list(range(prompt_len)))
+            num_orig_input_tokens_list.extend([seq_data.get_prompt_len()] * prompt_len)
 
             assert seq_group_metadata.block_tables is not None
             block_table = seq_group_metadata.block_tables[seq_id]
@@ -321,6 +332,7 @@
             num_paddings = padded_prompt_len - prompt_len
             input_tokens += [0] * num_paddings
             input_positions += [0] * num_paddings
+            num_orig_input_tokens_list += [0] * num_paddings
             slot_mapping += [_PAD_SLOT_ID] * num_paddings
 
         assert len(prompt_lens) > 0
@@ -331,6 +343,9 @@
         input_positions = torch.tensor(input_positions,
                                        dtype=torch.int32,
                                        device="cpu")
+        num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list,
+                                                    dtype=torch.long,
+                                                    device=self.device)  # type: ignore
         slot_mapping = torch.tensor(slot_mapping,
                                     dtype=torch.int64,
                                     device="cpu")
@@ -342,6 +357,7 @@ def _prepare_prompt(
             num_prefill_tokens=0,  # NOTE: This is not used.
             num_decode_tokens=0,
             slot_mapping=slot_mapping,
+            num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
             block_tables=None,
             context_lens=None,
         )
@@ -354,6 +370,8 @@ def _prepare_decode(
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[List[int]] = []
         input_positions: List[List[int]] = []
+        # The number of original input tokens of each sequence
+        num_orig_input_tokens_list: List[List[int]] = []
         slot_mapping: List[List[int]] = []
         context_lens: List[int] = []
 
@@ -369,6 +387,8 @@
                 seq_len = seq_data.get_len()
                 position = seq_len - 1
                 input_positions.append([position])
+                num_orig_input_tokens_list.append([seq_data.get_prompt_len()])
+
                 context_lens.append(seq_len)
 
                 assert seq_group_metadata.block_tables is not None
@@ -385,6 +405,8 @@
         num_paddings = batch_size - batch_idx
         input_tokens = input_tokens + [[0]] * num_paddings
         input_positions = input_positions + [[0]] * num_paddings
+        num_orig_input_tokens_list = num_orig_input_tokens_list + [[0]] * num_paddings
+
         slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
         context_lens = context_lens + [0] * num_paddings
 
@@ -394,6 +416,9 @@
         input_positions = torch.tensor(input_positions,
                                        dtype=torch.int32,
                                        device="cpu")
+        num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list,
+                                                    dtype=torch.long,
+                                                    device="cpu")
         slot_mapping = torch.tensor(slot_mapping,
                                     dtype=torch.int64,
                                     device="cpu")
@@ -413,6 +438,7 @@
             slot_mapping=slot_mapping,
             block_tables=block_tables,
             context_lens=context_lens,
+            num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
         )
         return input_tokens, input_positions, attn_metadata, input_lens
 
diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py
index 0335bbcd091e8..2efe0c53598e5 100644
--- a/vllm/worker/xpu_model_runner.py
+++ b/vllm/worker/xpu_model_runner.py
@@ -155,6 +155,8 @@ def _prepare_prompt(
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[int] = []
         input_positions: List[int] = []
+        # The number of original input tokens of each sequence
+        num_orig_input_tokens_list: List[int] = []
         slot_mapping: List[int] = []
         seq_lens: List[int] = []
         multi_modal_inputs_list: List[MultiModalInputs] = []
@@ -178,6 +180,8 @@
             # is always the first token in the sequence.
             input_positions.extend(list(range(computed_len, seq_len)))
 
+            num_orig_input_tokens_list.extend([seq_data.get_prompt_len()] * (seq_len - computed_len))
+
             if seq_group_metadata.block_tables is None:
                 # During memory profiling, the block tables are not initialized
                 # yet. In this case, we just use a dummy slot mapping.
@@ -214,6 +218,11 @@ def _prepare_prompt(
         input_positions = torch.tensor(input_positions,
                                        dtype=torch.long,
                                        device=self.device)  # type: ignore
+
+        num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list,
+                                                    dtype=torch.long,
+                                                    device=self.device)  # type: ignore
+
         slot_mapping = torch.tensor(slot_mapping,
                                     dtype=torch.long,
                                     device=self.device)  # type: ignore
@@ -236,6 +245,7 @@ def _prepare_prompt(
             num_prefill_tokens=num_prompt_tokens,
             num_decode_tokens=0,
             block_tables=torch.tensor([], device=self.device, dtype=torch.int),
+            num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
         )
 
         multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
@@ -250,6 +260,8 @@ def _prepare_decode(
         assert len(seq_group_metadata_list) > 0
         input_tokens: List[int] = []
         input_positions: List[int] = []
+        # The number of original input tokens of each sequence
+        num_orig_input_tokens_list: List[int] = []
         slot_mapping: List[int] = []
         seq_lens: List[int] = []
         block_tables: List[List[int]] = []
@@ -268,6 +280,7 @@
                 seq_len = seq_data.get_len()
                 position = seq_len - 1
                 input_positions.append(position)
+                num_orig_input_tokens_list.append(seq_data.get_prompt_len())
 
                 seq_len = seq_len if self.sliding_window is None else min(
                     seq_len, self.sliding_window)
@@ -299,6 +312,9 @@
         seq_lens_tensor = torch.tensor(seq_lens,
                                        dtype=torch.int,
                                        device=self.device)
+        num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list,
+                                                    dtype=torch.long,
+                                                    device=self.device)
 
         block_tables = make_tensor_with_pad(
             block_tables,
@@ -319,6 +335,7 @@
             num_decode_tokens=len(input_tokens),
             num_prefills=0,
             block_tables=block_tables,
+            num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
         )
         return (
             input_tokens,