From a000e1b8467381d0acef0f32ed2836a4c7a05d47 Mon Sep 17 00:00:00 2001
From: Jianwen Zhang
Date: Thu, 29 Aug 2024 05:45:11 +0000
Subject: [PATCH] fix blocksparse attn backend and phi3-small models

---
 vllm/attention/backends/blocksparse_attn.py | 4 ++++
 vllm/model_executor/models/phi3_small.py    | 5 ++++-
 2 files changed, 8 insertions(+), 1 deletion(-)

diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py
index d84a40890ebbd..dc92d4a7351e0 100644
--- a/vllm/attention/backends/blocksparse_attn.py
+++ b/vllm/attention/backends/blocksparse_attn.py
@@ -220,6 +220,8 @@ def prefill_metadata(
             query_start_loc=self.query_start_loc[:self.num_prefills + 1],
             seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
             context_lens_tensor=self.context_lens_tensor[:self.num_prefills],
+            num_orig_input_tokens_tensor=(None if self.num_orig_input_tokens_tensor is None
+                else self.num_orig_input_tokens_tensor[:self.num_prefills]),
             block_tables=self.block_tables[:self.num_prefills],
             use_cuda_graph=False,
         )
@@ -248,6 +250,8 @@ def decode_metadata(self) -> Optional["BlocksparseFlashAttentionMetadata"]:
             query_start_loc=None,
             seq_start_loc=None,
             context_lens_tensor=None,
+            num_orig_input_tokens_tensor=(None if self.num_orig_input_tokens_tensor is None
+                else self.num_orig_input_tokens_tensor[self.num_prefills:]),
             block_tables=self.block_tables[self.num_prefills:],
             use_cuda_graph=self.use_cuda_graph,
         )
diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py
index df01bfa3d8e6e..3a2336999b0cd 100644
--- a/vllm/model_executor/models/phi3_small.py
+++ b/vllm/model_executor/models/phi3_small.py
@@ -240,7 +240,10 @@ def forward(
         k = k.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
         v = v.reshape(-1, self.head_dim * self.num_kv_heads_per_partion)
 
-        q, k = self.rotary_emb(positions, q, k)
+        q, k = (self.rotary_emb(positions, q, k)
+                if getattr(self.config, "rope_scaling", None) is None else
+                self.rotary_emb(positions, q, k,
+                    num_orig_input_tokens_tensor=attn_metadata.num_orig_input_tokens_tensor))
 
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata=attn_metadata)
         output, _ = self.dense(attn_output)