Add fix for rocm, tpu, and xpu #8

Merged
merged 1 commit on Aug 31, 2024
2 changes: 2 additions & 0 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -136,6 +136,7 @@ def prefill_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
num_orig_input_tokens_tensor=self.num_orig_input_tokens_tensor[:self.num_prefills],
max_query_len=self.max_query_len,
max_prefill_seq_len=self.max_prefill_seq_len,
max_decode_seq_len=0,
@@ -164,6 +165,7 @@ def decode_metadata(self) -> Optional["ROCmFlashAttentionMetadata"]:
slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
num_orig_input_tokens_tensor=self.num_orig_input_tokens_tensor[self.num_prefills:],
max_query_len=None,
max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len,
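Note on the two ROCm hunks above: prefill and decode metadata are produced by slicing per-sequence tensors that are laid out with all prefill sequences first, followed by the decode sequences. The standalone sketch below (plain PyTorch, not the vLLM metadata class itself) illustrates that slicing convention for the new num_orig_input_tokens_tensor field; the lengths are made-up example values.

import torch

# Example layout: 2 prefill sequences followed by 3 decode sequences.
num_prefills = 2
# One entry per sequence: the original prompt length of that sequence.
num_orig_input_tokens_tensor = torch.tensor([8, 16, 4, 4, 8])

# Prefill metadata takes the leading slice, decode metadata the trailing one.
prefill_slice = num_orig_input_tokens_tensor[:num_prefills]   # tensor([ 8, 16])
decode_slice = num_orig_input_tokens_tensor[num_prefills:]    # tensor([4, 4, 8])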
26 changes: 26 additions & 0 deletions vllm/worker/tpu_model_runner.py
@@ -165,6 +165,9 @@ def _dummy_run(
position_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
num_orig_input_tokens_tensor = torch.full((batch_size, seq_len), seq_len,
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros((batch_size, seq_len),
dtype=torch.int64,
device=self.device)
@@ -173,6 +176,7 @@
num_prefill_tokens=batch_size * seq_len,
num_decode_tokens=0,
slot_mapping=slot_mapping,
num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
block_tables=None,
context_lens=None,
)
@@ -187,6 +191,9 @@ def _dummy_run(
position_ids = torch.zeros((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
num_orig_input_tokens_tensor = torch.ones((batch_size, seq_len),
dtype=torch.int32,
device=self.device)
slot_mapping = torch.zeros((batch_size, seq_len),
dtype=torch.int64,
device=self.device)
@@ -205,6 +212,7 @@
num_prefill_tokens=0,
num_decode_tokens=batch_size * seq_len,
slot_mapping=slot_mapping,
num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
block_tables=block_tables,
context_lens=context_lens,
)
@@ -285,6 +293,8 @@ def _prepare_prompt(
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
# The number of original input tokens of each sequence
num_orig_input_tokens_list: List[int] = []
prompt_lens: List[int] = []
slot_mapping: List[int] = []

@@ -302,6 +312,7 @@

input_tokens.extend(prompt_tokens)
input_positions.extend(list(range(prompt_len)))
num_orig_input_tokens_list.extend([seq_data.get_prompt_len()] * prompt_len)

assert seq_group_metadata.block_tables is not None
block_table = seq_group_metadata.block_tables[seq_id]
@@ -321,6 +332,7 @@
num_paddings = padded_prompt_len - prompt_len
input_tokens += [0] * num_paddings
input_positions += [0] * num_paddings
num_orig_input_tokens_list += [0] * num_paddings
slot_mapping += [_PAD_SLOT_ID] * num_paddings

assert len(prompt_lens) > 0
@@ -331,6 +343,9 @@
input_positions = torch.tensor(input_positions,
dtype=torch.int32,
device="cpu")
num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list,
dtype=torch.long,
device="cpu")
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.int64,
device="cpu")
@@ -342,6 +357,7 @@
num_prefill_tokens=0, # NOTE: This is not used.
num_decode_tokens=0,
slot_mapping=slot_mapping,
num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
block_tables=None,
context_lens=None,
)
@@ -354,6 +370,8 @@ def _prepare_decode(
assert len(seq_group_metadata_list) > 0
input_tokens: List[List[int]] = []
input_positions: List[List[int]] = []
# The number of original input tokens of each sequence
num_orig_input_tokens_list: List[List[int]] = []
slot_mapping: List[List[int]] = []
context_lens: List[int] = []

@@ -369,6 +387,8 @@
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append([position])
num_orig_input_tokens_list.append([seq_data.get_prompt_len()])

context_lens.append(seq_len)

assert seq_group_metadata.block_tables is not None
@@ -385,6 +405,8 @@
num_paddings = batch_size - batch_idx
input_tokens = input_tokens + [[0]] * num_paddings
input_positions = input_positions + [[0]] * num_paddings
num_orig_input_tokens_list = num_orig_input_tokens_list + [[0]] * num_paddings

slot_mapping = slot_mapping + [[_PAD_SLOT_ID]] * num_paddings
context_lens = context_lens + [0] * num_paddings

@@ -394,6 +416,9 @@
input_positions = torch.tensor(input_positions,
dtype=torch.int32,
device="cpu")
num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list,
dtype=torch.long,
device="cpu")
slot_mapping = torch.tensor(slot_mapping,
dtype=torch.int64,
device="cpu")
@@ -413,6 +438,7 @@
slot_mapping=slot_mapping,
block_tables=block_tables,
context_lens=context_lens,
num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
)
return input_tokens, input_positions, attn_metadata, input_lens

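For readers following the TPU changes above: in the prompt path the new num_orig_input_tokens_list receives one entry per prompt token (all equal to that sequence's original prompt length) and zeros for padding slots, mirroring how input_tokens and input_positions are padded. Below is a minimal, self-contained sketch of that pattern; the helper name and the example lengths are illustrative, not part of the diff.

from typing import List

import torch


def build_num_orig_input_tokens(prompt_lens: List[int],
                                padded_len: int) -> torch.Tensor:
    # Assumed simplification: every prompt is padded to the same length.
    values: List[int] = []
    for prompt_len in prompt_lens:
        values.extend([prompt_len] * prompt_len)         # one entry per real token
        values.extend([0] * (padded_len - prompt_len))   # zeros for padding slots
    return torch.tensor(values, dtype=torch.int32, device="cpu")


# Two prompts of length 3 and 5, both padded to 8 tokens:
# tensor([3, 3, 3, 0, 0, 0, 0, 0, 5, 5, 5, 5, 5, 0, 0, 0], dtype=torch.int32)
print(build_num_orig_input_tokens([3, 5], padded_len=8))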
17 changes: 17 additions & 0 deletions vllm/worker/xpu_model_runner.py
@@ -155,6 +155,8 @@ def _prepare_prompt(
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
# The number of original input tokens of each sequence
num_orig_input_tokens_list: List[int] = []
slot_mapping: List[int] = []
seq_lens: List[int] = []
multi_modal_inputs_list: List[MultiModalInputs] = []
@@ -178,6 +180,8 @@
# is always the first token in the sequence.
input_positions.extend(list(range(computed_len, seq_len)))

num_orig_input_tokens_list.extend([seq_data.get_prompt_len()] * (seq_len - computed_len))

if seq_group_metadata.block_tables is None:
# During memory profiling, the block tables are not initialized
# yet. In this case, we just use a dummy slot mapping.
@@ -214,6 +218,11 @@
input_positions = torch.tensor(input_positions,
dtype=torch.long,
device=self.device) # type: ignore

num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list,
dtype=torch.long,
device=self.device) # type: ignore

slot_mapping = torch.tensor(slot_mapping,
dtype=torch.long,
device=self.device) # type: ignore
@@ -236,6 +245,7 @@
num_prefill_tokens=num_prompt_tokens,
num_decode_tokens=0,
block_tables=torch.tensor([], device=self.device, dtype=torch.int),
num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
)

multi_modal_kwargs = MultiModalInputs.batch(multi_modal_inputs_list)
@@ -250,6 +260,8 @@ def _prepare_decode(
assert len(seq_group_metadata_list) > 0
input_tokens: List[int] = []
input_positions: List[int] = []
# The number of original input tokens of each sequence
num_orig_input_tokens_list: List[int] = []
slot_mapping: List[int] = []
seq_lens: List[int] = []
block_tables: List[List[int]] = []
@@ -268,6 +280,7 @@
seq_len = seq_data.get_len()
position = seq_len - 1
input_positions.append(position)
num_orig_input_tokens_list.append(seq_data.get_prompt_len())

seq_len = seq_len if self.sliding_window is None else min(
seq_len, self.sliding_window)
@@ -299,6 +312,9 @@
seq_lens_tensor = torch.tensor(seq_lens,
dtype=torch.int,
device=self.device)
num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list,
dtype=torch.long,
device=self.device)

block_tables = make_tensor_with_pad(
block_tables,
@@ -319,6 +335,7 @@
num_decode_tokens=len(input_tokens),
num_prefills=0,
block_tables=block_tables,
num_orig_input_tokens_tensor=num_orig_input_tokens_tensor,
)
return (
input_tokens,
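In the XPU decode path above, each running sequence contributes a single position (its last token) and a single num_orig_input_tokens entry equal to its original prompt length. The sketch below shows that bookkeeping in isolation; SequenceData here is a stand-in with just the two accessors the diff uses, not vLLM's actual class.

from dataclasses import dataclass, field
from typing import List

import torch


@dataclass
class SequenceData:  # stand-in for vllm.sequence.SequenceData
    prompt_token_ids: List[int]
    output_token_ids: List[int] = field(default_factory=list)

    def get_len(self) -> int:
        return len(self.prompt_token_ids) + len(self.output_token_ids)

    def get_prompt_len(self) -> int:
        return len(self.prompt_token_ids)


seqs = [SequenceData([1, 2, 3], [4, 5]), SequenceData([7, 8], [9])]

input_positions: List[int] = []
num_orig_input_tokens_list: List[int] = []
for seq_data in seqs:
    input_positions.append(seq_data.get_len() - 1)              # last token position
    num_orig_input_tokens_list.append(seq_data.get_prompt_len())

num_orig_input_tokens_tensor = torch.tensor(num_orig_input_tokens_list,
                                            dtype=torch.long)
print(input_positions)               # [4, 2]
print(num_orig_input_tokens_tensor)  # tensor([3, 2])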