Skip to content

Commit

Permalink
Merge pull request #6 from jiazhan-msft/dev/jiazhan/fix_blocksparseattn_phi3small
Browse files Browse the repository at this point in the history

fix blocksparse attn backend and phi3-small models
  • Loading branch information
jiazhan-msft authored Aug 29, 2024
2 parents 7577906 + a000e1b commit 9c0d2c2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
4 changes: 4 additions & 0 deletions vllm/attention/backends/blocksparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ def prefill_metadata(
query_start_loc=self.query_start_loc[:self.num_prefills + 1],
seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
num_orig_input_tokens_tensor = (None if self.num_orig_input_tokens_tensor is None else
self.num_orig_input_tokens_tensor[:self.num_prefills]),
block_tables=self.block_tables[:self.num_prefills],
use_cuda_graph=False,
)
Expand Down Expand Up @@ -248,6 +250,8 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
query_start_loc=None,
seq_start_loc=None,
context_lens_tensor=None,
num_orig_input_tokens_tensor = (None if self.num_orig_input_tokens_tensor is None else
self.num_orig_input_tokens_tensor[:self.num_prefills]),
block_tables=self.block_tables[self.num_prefills:],
use_cuda_graph=self.use_cuda_graph,
)
Expand Down
5 changes: 4 additions & 1 deletion vllm/model_executor/models/phi3_small.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,10 @@ def forward(
k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)

q, k = self.rotary_emb(positions, q, k)
q, k = self.rotary_emb(positions, q, k) \
if getattr(self.config, "rope_scaling", None) is None \
else self.rotary_emb(positions, q, k, num_orig_input_tokens_tensor=attn_metadata.num_orig_input_tokens_tensor)

attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata)
output, _ = self.dense(attn_output)

Expand Down

0 comments on commit 9c0d2c2

Please sign in to comment.